mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-29 01:27:06 +08:00
Merge branch 'main' into imarkov/eplb_optimizations
This commit is contained in:
commit
b57a04516e
@ -36,6 +36,11 @@ function cpu_tests() {
|
|||||||
set -e
|
set -e
|
||||||
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
|
python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m"
|
||||||
|
|
||||||
|
# Run model tests
|
||||||
|
docker exec cpu-test bash -c "
|
||||||
|
set -e
|
||||||
|
pytest -x -v -s tests/models/multimodal/generation/test_whisper.py -m cpu_model"
|
||||||
|
|
||||||
# Run kernel tests
|
# Run kernel tests
|
||||||
docker exec cpu-test bash -c "
|
docker exec cpu-test bash -c "
|
||||||
set -e
|
set -e
|
||||||
|
|||||||
@ -1,73 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
set -euxo pipefail
|
|
||||||
|
|
||||||
# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
|
|
||||||
THRESHOLD=${1:-0.25}
|
|
||||||
NUM_Q=${2:-1319}
|
|
||||||
PORT=${3:-8030}
|
|
||||||
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
|
|
||||||
mkdir -p "${OUT_DIR}"
|
|
||||||
|
|
||||||
wait_for_server() {
|
|
||||||
local port=$1
|
|
||||||
timeout 600 bash -c '
|
|
||||||
until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do
|
|
||||||
sleep 1
|
|
||||||
done'
|
|
||||||
}
|
|
||||||
|
|
||||||
MODEL="deepseek-ai/DeepSeek-V2-lite"
|
|
||||||
|
|
||||||
# Set BACKENDS based on platform
|
|
||||||
if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
|
|
||||||
# ROCm platform
|
|
||||||
BACKENDS=("allgather_reducescatter")
|
|
||||||
# Disable MOE padding for ROCm since it is causing eplb to fail
|
|
||||||
export VLLM_ROCM_MOE_PADDING=0
|
|
||||||
else
|
|
||||||
# Non-ROCm platform (CUDA/other)
|
|
||||||
BACKENDS=("deepep_high_throughput" "deepep_low_latency")
|
|
||||||
fi
|
|
||||||
|
|
||||||
cleanup() {
|
|
||||||
if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
|
|
||||||
kill "${SERVER_PID}" 2>/dev/null || true
|
|
||||||
for _ in {1..20}; do
|
|
||||||
kill -0 "${SERVER_PID}" 2>/dev/null || break
|
|
||||||
sleep 0.5
|
|
||||||
done
|
|
||||||
kill -9 "${SERVER_PID}" 2>/dev/null || true
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
trap cleanup EXIT
|
|
||||||
|
|
||||||
for BACK in "${BACKENDS[@]}"; do
|
|
||||||
VLLM_DEEP_GEMM_WARMUP=skip \
|
|
||||||
VLLM_ALL2ALL_BACKEND=$BACK \
|
|
||||||
vllm serve "$MODEL" \
|
|
||||||
--enforce-eager \
|
|
||||||
--tensor-parallel-size 2 \
|
|
||||||
--data-parallel-size 2 \
|
|
||||||
--enable-expert-parallel \
|
|
||||||
--enable-eplb \
|
|
||||||
--eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \
|
|
||||||
--trust-remote-code \
|
|
||||||
--max-model-len 2048 \
|
|
||||||
--port $PORT &
|
|
||||||
SERVER_PID=$!
|
|
||||||
wait_for_server $PORT
|
|
||||||
|
|
||||||
TAG=$(echo "$MODEL" | tr '/: \\n' '_____')
|
|
||||||
OUT="${OUT_DIR}/${TAG}_${BACK}_async_eplb.json"
|
|
||||||
python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port $PORT --num-questions ${NUM_Q} --save-results ${OUT}
|
|
||||||
python3 - <<PY
|
|
||||||
import json; acc=json.load(open('${OUT}'))['accuracy']
|
|
||||||
print(f"${MODEL} ${BACK}: accuracy {acc:.3f}")
|
|
||||||
assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}"
|
|
||||||
PY
|
|
||||||
|
|
||||||
cleanup
|
|
||||||
SERVER_PID=
|
|
||||||
sleep 1
|
|
||||||
PORT=$((PORT+1))
|
|
||||||
done
|
|
||||||
@ -50,7 +50,6 @@ for BACK in "${BACKENDS[@]}"; do
|
|||||||
--data-parallel-size 2 \
|
--data-parallel-size 2 \
|
||||||
--enable-expert-parallel \
|
--enable-expert-parallel \
|
||||||
--enable-eplb \
|
--enable-eplb \
|
||||||
--eplb-config '{"window_size":200,"step_interval":600}' \
|
|
||||||
--trust-remote-code \
|
--trust-remote-code \
|
||||||
--max-model-len 2048 \
|
--max-model-len 2048 \
|
||||||
--port $PORT &
|
--port $PORT &
|
||||||
|
|||||||
@ -1,74 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
set -euxo pipefail
|
|
||||||
|
|
||||||
# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
|
|
||||||
THRESHOLD=${1:-0.25}
|
|
||||||
NUM_Q=${2:-1319}
|
|
||||||
PORT=${3:-8040}
|
|
||||||
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
|
|
||||||
mkdir -p "${OUT_DIR}"
|
|
||||||
|
|
||||||
wait_for_server() {
|
|
||||||
local port=$1
|
|
||||||
timeout 600 bash -c '
|
|
||||||
until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do
|
|
||||||
sleep 1
|
|
||||||
done'
|
|
||||||
}
|
|
||||||
|
|
||||||
MODEL="Qwen/Qwen3-Next-80B-A3B-Instruct"
|
|
||||||
|
|
||||||
# Set BACKENDS based on platform
|
|
||||||
if command -v rocm-smi &> /dev/null || [[ -d /opt/rocm ]] || [[ -n "${ROCM_PATH:-}" ]]; then
|
|
||||||
# ROCm platform
|
|
||||||
BACKENDS=("allgather_reducescatter")
|
|
||||||
# Disable MOE padding for ROCm since it is causing eplb to fail
|
|
||||||
export VLLM_ROCM_MOE_PADDING=0
|
|
||||||
else
|
|
||||||
# Non-ROCm platform (CUDA/other)
|
|
||||||
BACKENDS=("deepep_high_throughput" "deepep_low_latency")
|
|
||||||
fi
|
|
||||||
|
|
||||||
cleanup() {
|
|
||||||
if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
|
|
||||||
kill "${SERVER_PID}" 2>/dev/null || true
|
|
||||||
for _ in {1..20}; do
|
|
||||||
kill -0 "${SERVER_PID}" 2>/dev/null || break
|
|
||||||
sleep 0.5
|
|
||||||
done
|
|
||||||
kill -9 "${SERVER_PID}" 2>/dev/null || true
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
trap cleanup EXIT
|
|
||||||
|
|
||||||
for BACK in "${BACKENDS[@]}"; do
|
|
||||||
VLLM_DEEP_GEMM_WARMUP=skip \
|
|
||||||
VLLM_ALL2ALL_BACKEND=$BACK \
|
|
||||||
vllm serve "$MODEL" \
|
|
||||||
--enforce-eager \
|
|
||||||
--tensor-parallel-size 4 \
|
|
||||||
--enable-expert-parallel \
|
|
||||||
--enable-eplb \
|
|
||||||
--eplb-config '{"window_size":200,"step_interval":600,"use_async":true}' \
|
|
||||||
--speculative-config '{"method":"qwen3_next_mtp","num_speculative_tokens":1}' \
|
|
||||||
--trust-remote-code \
|
|
||||||
--max-model-len 2048 \
|
|
||||||
--gpu-memory-utilization 0.9 \
|
|
||||||
--port $PORT &
|
|
||||||
SERVER_PID=$!
|
|
||||||
wait_for_server $PORT
|
|
||||||
|
|
||||||
TAG=$(echo "$MODEL" | tr '/: \\n' '_____')
|
|
||||||
OUT="${OUT_DIR}/${TAG}_${BACK}.json"
|
|
||||||
python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port $PORT --num-questions ${NUM_Q} --save-results ${OUT}
|
|
||||||
python3 - <<PY
|
|
||||||
import json; acc=json.load(open('${OUT}'))['accuracy']
|
|
||||||
print(f"${MODEL} ${BACK}: accuracy {acc:.3f}")
|
|
||||||
assert acc >= ${THRESHOLD}, f"${MODEL} ${BACK} accuracy {acc}"
|
|
||||||
PY
|
|
||||||
|
|
||||||
cleanup
|
|
||||||
SERVER_PID=
|
|
||||||
sleep 1
|
|
||||||
PORT=$((PORT+1))
|
|
||||||
done
|
|
||||||
@ -1380,21 +1380,3 @@ steps:
|
|||||||
working_dir: "/vllm-workspace"
|
working_dir: "/vllm-workspace"
|
||||||
commands:
|
commands:
|
||||||
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
|
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1
|
||||||
|
|
||||||
- label: DeepSeek V2-Lite Async EPLB Accuracy
|
|
||||||
timeout_in_minutes: 60
|
|
||||||
gpu: h100
|
|
||||||
optional: true
|
|
||||||
num_gpus: 4
|
|
||||||
working_dir: "/vllm-workspace"
|
|
||||||
commands:
|
|
||||||
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_ep_async_eplb.sh 0.25 1319 8030
|
|
||||||
|
|
||||||
- label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy
|
|
||||||
timeout_in_minutes: 60
|
|
||||||
gpu: h100
|
|
||||||
optional: true
|
|
||||||
num_gpus: 4
|
|
||||||
working_dir: "/vllm-workspace"
|
|
||||||
commands:
|
|
||||||
- bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040
|
|
||||||
|
|||||||
@ -32,12 +32,11 @@ def benchmark_propose(args):
|
|||||||
|
|
||||||
model_config = ModelConfig(
|
model_config = ModelConfig(
|
||||||
model="facebook/opt-125m",
|
model="facebook/opt-125m",
|
||||||
task="generate",
|
|
||||||
max_model_len=args.num_token + args.num_spec_token,
|
max_model_len=args.num_token + args.num_spec_token,
|
||||||
tokenizer="facebook/opt-125m",
|
tokenizer="facebook/opt-125m",
|
||||||
tokenizer_mode="auto",
|
tokenizer_mode="auto",
|
||||||
dtype="auto",
|
dtype="auto",
|
||||||
seed=None,
|
seed=0,
|
||||||
trust_remote_code=False,
|
trust_remote_code=False,
|
||||||
)
|
)
|
||||||
proposer = NgramProposer(
|
proposer = NgramProposer(
|
||||||
|
|||||||
@ -574,7 +574,7 @@ async def benchmark(
|
|||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
"{:<40} {:<10.2f}".format(
|
"{:<40} {:<10.2f}".format(
|
||||||
"Total Token throughput (tok/s):", metrics.total_token_throughput
|
"Total token throughput (tok/s):", metrics.total_token_throughput
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
150
benchmarks/kernels/benchmark_mla_k_concat.py
Normal file
150
benchmarks/kernels/benchmark_mla_k_concat.py
Normal file
@ -0,0 +1,150 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Benchmark script comparing torch.cat vs direct copy for k_nope/k_pe concatenation
|
||||||
|
in MLA (Multi-head Latent Attention) prefill.
|
||||||
|
|
||||||
|
This validates that the optimization from commit 8d4142bd is beneficial across
|
||||||
|
various batch sizes, not just the originally tested batch size of 32768.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# DeepSeek-V3 MLA dimensions
|
||||||
|
NUM_HEADS = 128
|
||||||
|
QK_NOPE_HEAD_DIM = 128
|
||||||
|
PE_DIM = 64
|
||||||
|
|
||||||
|
|
||||||
|
def cat_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Original torch.cat approach with expand."""
|
||||||
|
return torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def direct_copy_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Optimized direct copy approach (avoids expand + cat overhead)."""
|
||||||
|
k = torch.empty(
|
||||||
|
(*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
|
||||||
|
dtype=k_nope.dtype,
|
||||||
|
device=k_nope.device,
|
||||||
|
)
|
||||||
|
k[..., : k_nope.shape[-1]] = k_nope
|
||||||
|
k[..., k_nope.shape[-1] :] = k_pe
|
||||||
|
return k
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_method(
|
||||||
|
method: Callable,
|
||||||
|
k_nope: torch.Tensor,
|
||||||
|
k_pe: torch.Tensor,
|
||||||
|
num_warmup: int = 10,
|
||||||
|
num_iters: int = 100,
|
||||||
|
) -> float:
|
||||||
|
"""Benchmark a concatenation method and return mean latency in ms."""
|
||||||
|
# Warmup
|
||||||
|
for _ in range(num_warmup):
|
||||||
|
_ = method(k_nope, k_pe)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Benchmark
|
||||||
|
start = time.perf_counter()
|
||||||
|
for _ in range(num_iters):
|
||||||
|
_ = method(k_nope, k_pe)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
end = time.perf_counter()
|
||||||
|
|
||||||
|
return (end - start) / num_iters * 1000 # Convert to ms
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def run_benchmark(dtype: torch.dtype, dtype_name: str):
|
||||||
|
"""Run benchmark for a specific dtype."""
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
|
# Batch sizes to test (powers of 2 from 32 to 65536)
|
||||||
|
batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
print("Benchmark: torch.cat vs direct copy for MLA k_nope/k_pe concatenation")
|
||||||
|
print("=" * 80)
|
||||||
|
print(
|
||||||
|
f"Tensor shapes: k_nope=[B, {NUM_HEADS}, {QK_NOPE_HEAD_DIM}], "
|
||||||
|
f"k_pe=[B, 1, {PE_DIM}]"
|
||||||
|
)
|
||||||
|
print(f"dtype: {dtype_name}")
|
||||||
|
print()
|
||||||
|
print(
|
||||||
|
f"{'Batch Size':>12} | {'cat (ms)':>10} | {'direct (ms)':>12} | "
|
||||||
|
f"{'Speedup':>8} | {'Reduction':>10}"
|
||||||
|
)
|
||||||
|
print("-" * 70)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for batch_size in batch_sizes:
|
||||||
|
# Create input tensors (generate in float32 then convert for FP8 compatibility)
|
||||||
|
k_nope = torch.randn(
|
||||||
|
batch_size, NUM_HEADS, QK_NOPE_HEAD_DIM, dtype=torch.float32, device="cuda"
|
||||||
|
).to(dtype)
|
||||||
|
k_pe = torch.randn(
|
||||||
|
batch_size, 1, PE_DIM, dtype=torch.float32, device="cuda"
|
||||||
|
).to(dtype)
|
||||||
|
|
||||||
|
# Benchmark both methods
|
||||||
|
cat_time = benchmark_method(cat_method, k_nope, k_pe)
|
||||||
|
direct_time = benchmark_method(direct_copy_method, k_nope, k_pe)
|
||||||
|
|
||||||
|
speedup = cat_time / direct_time
|
||||||
|
reduction = (1 - direct_time / cat_time) * 100
|
||||||
|
|
||||||
|
results.append((batch_size, cat_time, direct_time, speedup, reduction))
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"{batch_size:>12} | {cat_time:>10.3f} | {direct_time:>12.3f} | "
|
||||||
|
f"{speedup:>7.2f}x | {reduction:>9.1f}%"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# Summary statistics
|
||||||
|
speedups = [r[3] for r in results]
|
||||||
|
print("\nSpeedup summary:")
|
||||||
|
print(f" Min: {min(speedups):.2f}x")
|
||||||
|
print(f" Max: {max(speedups):.2f}x")
|
||||||
|
print(f" Mean: {sum(speedups) / len(speedups):.2f}x")
|
||||||
|
|
||||||
|
# Find crossover point
|
||||||
|
crossover_batch = None
|
||||||
|
for batch_size, _, _, speedup, _ in results:
|
||||||
|
if speedup >= 1.0:
|
||||||
|
crossover_batch = batch_size
|
||||||
|
break
|
||||||
|
|
||||||
|
print("\nConclusion:")
|
||||||
|
if crossover_batch:
|
||||||
|
print(f" - Direct copy becomes beneficial at batch size >= {crossover_batch}")
|
||||||
|
# Filter for large batches (>= 512 which is typical for prefill)
|
||||||
|
large_batch_speedups = [r[3] for r in results if r[0] >= 512]
|
||||||
|
if large_batch_speedups:
|
||||||
|
avg_large = sum(large_batch_speedups) / len(large_batch_speedups)
|
||||||
|
print(f" - For batch sizes >= 512: avg speedup = {avg_large:.2f}x")
|
||||||
|
print(" - MLA prefill typically uses large batches, so optimization is effective")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def main():
|
||||||
|
# Test bfloat16
|
||||||
|
print("\n")
|
||||||
|
run_benchmark(torch.bfloat16, "bfloat16")
|
||||||
|
|
||||||
|
# Test float8_e4m3fn
|
||||||
|
print("\n")
|
||||||
|
run_benchmark(torch.float8_e4m3fn, "float8_e4m3fn")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -251,17 +251,6 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
# Build ACL with CMake
|
# Build ACL with CMake
|
||||||
set(ARM_COMPUTE_BUILD_SHARED_LIB "OFF")
|
|
||||||
set(CMAKE_BUILD_TYPE "Release")
|
|
||||||
set(ARM_COMPUTE_ARCH "armv8.2-a")
|
|
||||||
set(ARM_COMPUTE_ENABLE_ASSERTS "OFF")
|
|
||||||
set(ARM_COMPUTE_ENABLE_CPPTHREADS "OFF")
|
|
||||||
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
|
|
||||||
set(ARM_COMPUTE_ENABLE_OPENMP "ON")
|
|
||||||
set(ARM_COMPUTE_ENABLE_WERROR "OFF")
|
|
||||||
set(ARM_COMPUTE_BUILD_EXAMPLES "OFF")
|
|
||||||
set(ARM_COMPUTE_BUILD_TESTING "OFF")
|
|
||||||
|
|
||||||
set(_cmake_config_cmd
|
set(_cmake_config_cmd
|
||||||
${CMAKE_COMMAND} -G Ninja -B build
|
${CMAKE_COMMAND} -G Ninja -B build
|
||||||
-DARM_COMPUTE_BUILD_SHARED_LIB=OFF
|
-DARM_COMPUTE_BUILD_SHARED_LIB=OFF
|
||||||
|
|||||||
@ -117,7 +117,6 @@ torch::Tensor get_scheduler_metadata(
|
|||||||
input.casual = casual;
|
input.casual = casual;
|
||||||
input.isa = isa;
|
input.isa = isa;
|
||||||
input.enable_kv_split = enable_kv_split;
|
input.enable_kv_split = enable_kv_split;
|
||||||
TORCH_CHECK(casual, "Only supports casual mask for now.");
|
|
||||||
|
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
|
VLLM_DISPATCH_FLOATING_TYPES(dtype, "get_scheduler_metadata", [&]() {
|
||||||
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
|
CPU_ATTN_DISPATCH_CASE_HEADDIM(head_dim, [&] {
|
||||||
|
|||||||
@ -186,7 +186,7 @@ struct AttentionMetadata {
|
|||||||
// - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size + 2
|
// - Intermediate outputs: q_tile_size * head_dim * output_buffer_elem_size + 2
|
||||||
// * q_tile_size * 4, partial output, max + sum (float)
|
// * q_tile_size * 4, partial output, max + sum (float)
|
||||||
// Reduction scratchpad contains:
|
// Reduction scratchpad contains:
|
||||||
// - flags: bool array to indicate wether the split is finished
|
// - flags: bool array to indicate whether the split is finished
|
||||||
// - outputs: split_num * q_tile_size * head_dim * output_buffer_elem_size
|
// - outputs: split_num * q_tile_size * head_dim * output_buffer_elem_size
|
||||||
// - max, sum: 2 * split_num * q_tile_size * 4
|
// - max, sum: 2 * split_num * q_tile_size * 4
|
||||||
class AttentionScratchPad {
|
class AttentionScratchPad {
|
||||||
|
|||||||
@ -617,7 +617,7 @@ struct MacheteCollectiveMma {
|
|||||||
|
|
||||||
// Same as upstream, should be kept the same when possible, not formatted for
|
// Same as upstream, should be kept the same when possible, not formatted for
|
||||||
// easier comparison
|
// easier comparison
|
||||||
// with `SwapAB ? N : M -> M` since we dont support SwapAB
|
// with `SwapAB ? N : M -> M` since we don't support SwapAB
|
||||||
// clang-format off
|
// clang-format off
|
||||||
template<class ProblemShape>
|
template<class ProblemShape>
|
||||||
static bool
|
static bool
|
||||||
|
|||||||
@ -1241,33 +1241,16 @@ __global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx,
|
|||||||
}
|
}
|
||||||
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
|
#endif // defined(__HIP__GFX9__) TODO: Add NAVI support
|
||||||
|
|
||||||
|
// Find the min val of div2 that doesn't increase N/(div1*div2)
|
||||||
int mindiv(int N, int div1, int div2) {
|
int mindiv(int N, int div1, int div2) {
|
||||||
int nPrRnd = div1 * div2;
|
int nPrRnd = div1 * div2;
|
||||||
int rnds0 = N / nPrRnd;
|
int rnds[13];
|
||||||
nPrRnd -= div1 * 3;
|
for (int i = 0; i < 13; i++) {
|
||||||
int rnds3 = N / nPrRnd;
|
rnds[i] = (N + nPrRnd - 1) / nPrRnd;
|
||||||
nPrRnd -= div1;
|
nPrRnd -= div1;
|
||||||
int rnds4 = N / nPrRnd;
|
}
|
||||||
nPrRnd -= div1;
|
for (int i = 12; i >= 0; i--)
|
||||||
int rnds5 = N / nPrRnd;
|
if (rnds[0] == rnds[i]) return (div2 - i);
|
||||||
nPrRnd -= div1;
|
|
||||||
int rnds6 = N / nPrRnd;
|
|
||||||
nPrRnd -= div1;
|
|
||||||
int rnds7 = N / nPrRnd;
|
|
||||||
nPrRnd -= div1;
|
|
||||||
int rnds8 = N / nPrRnd;
|
|
||||||
nPrRnd -= div1;
|
|
||||||
int rnds9 = N / nPrRnd;
|
|
||||||
nPrRnd -= div1;
|
|
||||||
int rtn = div2;
|
|
||||||
if (rnds0 == rnds3) rtn = div2 - 3;
|
|
||||||
if (rnds0 == rnds4) rtn = div2 - 4;
|
|
||||||
if (rnds0 == rnds5) rtn = div2 - 5;
|
|
||||||
if (rnds0 == rnds6) rtn = div2 - 6;
|
|
||||||
if (rnds0 == rnds7) rtn = div2 - 7;
|
|
||||||
if (rnds0 == rnds8) rtn = div2 - 8;
|
|
||||||
if (rnds0 == rnds9) rtn = div2 - 9;
|
|
||||||
return rtn;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
||||||
@ -1300,26 +1283,37 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
|||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
const int max_lds_len = get_lds_size() / 2;
|
const int max_lds_len = get_lds_size() / 2;
|
||||||
|
|
||||||
#define WVSPLITK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \
|
#define WVSPLITK(_YTILE, _UNRL, _N) \
|
||||||
_N) \
|
{ \
|
||||||
{ \
|
dim3 block(64, 16); \
|
||||||
dim3 block(64, _WvPrGrp); \
|
int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \
|
||||||
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILEs == 0)) { \
|
if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \
|
||||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEs, _WvPrGrp); \
|
wvSplitK_hf_sml_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
||||||
wvSplitK_hf_sml_<fptype, 64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \
|
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
biasf4, c, __wvPrGrp, CuCount); \
|
||||||
biasf4, c, __wvPrGrp, CuCount); \
|
else if (K_in * N_in <= max_lds_len * 1.2) \
|
||||||
} else if (K_in * N_in <= max_lds_len * 1.2) { \
|
wvSplitK_hf_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
||||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEm, _WvPrGrp); \
|
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||||
wvSplitK_hf_<fptype, 64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \
|
biasf4, c, __wvPrGrp, CuCount); \
|
||||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
else \
|
||||||
biasf4, c, __wvPrGrp, CuCount); \
|
wvSplitK_hf_big_<fptype, 64, _YTILE, 16, 8, _UNRL, _N> \
|
||||||
} else { \
|
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
||||||
int __wvPrGrp = mindiv(M_in, CuCount * _YTILEb, _WvPrGrp); \
|
biasf4, c, __wvPrGrp, CuCount); \
|
||||||
wvSplitK_hf_big_<fptype, 64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \
|
}
|
||||||
<<<grid, block, 0, stream>>>(K_in, M_in, Bx_in, By_in, af4, bf4, \
|
|
||||||
biasf4, c, __wvPrGrp, CuCount); \
|
#define WVSPLIT_TILE(_sYT, __N) \
|
||||||
} \
|
{ \
|
||||||
|
bool fit_lds = (K_in * N_in <= max_lds_len); \
|
||||||
|
if (_sYT <= 1) \
|
||||||
|
WVSPLITK(1, 4, __N) \
|
||||||
|
else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \
|
||||||
|
WVSPLITK(2, 2, __N) \
|
||||||
|
else if (_sYT <= 4 * 3) \
|
||||||
|
WVSPLITK(3, 2, __N) \
|
||||||
|
else if (__N == 4) \
|
||||||
|
WVSPLITK(4, 1, __N) \
|
||||||
|
else \
|
||||||
|
WVSPLITK(4, 2, __N) \
|
||||||
}
|
}
|
||||||
|
|
||||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitK", [&] {
|
||||||
@ -1331,18 +1325,23 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b,
|
|||||||
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
|
? reinterpret_cast<const fptype*>(in_bias->data_ptr())
|
||||||
: nullptr;
|
: nullptr;
|
||||||
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
|
fptype* c = reinterpret_cast<fptype*>(out_c.data_ptr());
|
||||||
|
|
||||||
|
// first shoot for biggest tile-size that keeps all simd busy,
|
||||||
|
// then cut the active waves to balance their distribution...
|
||||||
|
int sYT = (M_in + CuCount * 4 - 1) / (CuCount * 4);
|
||||||
|
|
||||||
switch (N_in) {
|
switch (N_in) {
|
||||||
case 1:
|
case 1:
|
||||||
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 1)
|
WVSPLIT_TILE(sYT, 1)
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
WVSPLITK(16, 2, 2, 2, 2, 2, 2, 2)
|
WVSPLIT_TILE(sYT, 2)
|
||||||
break;
|
break;
|
||||||
case 3:
|
case 3:
|
||||||
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 3)
|
WVSPLIT_TILE(sYT, 3)
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
WVSPLITK(16, 4, 7, 7, 1, 1, 1, 4)
|
WVSPLIT_TILE(sYT, 4)
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
|
|||||||
@ -84,7 +84,7 @@ Total input tokens: 1369
|
|||||||
Total generated tokens: 2212
|
Total generated tokens: 2212
|
||||||
Request throughput (req/s): 1.73
|
Request throughput (req/s): 1.73
|
||||||
Output token throughput (tok/s): 382.89
|
Output token throughput (tok/s): 382.89
|
||||||
Total Token throughput (tok/s): 619.85
|
Total token throughput (tok/s): 619.85
|
||||||
---------------Time to First Token----------------
|
---------------Time to First Token----------------
|
||||||
Mean TTFT (ms): 71.54
|
Mean TTFT (ms): 71.54
|
||||||
Median TTFT (ms): 73.88
|
Median TTFT (ms): 73.88
|
||||||
|
|||||||
@ -21,30 +21,20 @@ The mental model is that server-level metrics help explain the values of request
|
|||||||
|
|
||||||
### v1 Metrics
|
### v1 Metrics
|
||||||
|
|
||||||
In v1, the following metrics are exposed via a Prometheus-compatible `/metrics` endpoint using the `vllm:` prefix:
|
In v1, an extensive set of metrics are exposed via a Prometheus-compatible `/metrics` endpoint using the `vllm:` prefix, for example:
|
||||||
|
|
||||||
- `vllm:num_requests_running` (Gauge) - Number of requests currently running.
|
- `vllm:num_requests_running` (Gauge) - Number of requests currently running.
|
||||||
- `vllm:num_requests_waiting` (Gauge) - Number of requests currently waiting.
|
|
||||||
- `vllm:kv_cache_usage_perc` (Gauge) - Fraction of used KV cache blocks (0–1).
|
- `vllm:kv_cache_usage_perc` (Gauge) - Fraction of used KV cache blocks (0–1).
|
||||||
- `vllm:prefix_cache_queries` (Counter) - Number of prefix cache queries.
|
- `vllm:prefix_cache_queries` (Counter) - Number of prefix cache queries.
|
||||||
- `vllm:prefix_cache_hits` (Counter) - Number of prefix cache hits.
|
- `vllm:prefix_cache_hits` (Counter) - Number of prefix cache hits.
|
||||||
- `vllm:mm_cache_queries` (Counter) - (For multimodal models) Number of multimodal cache queries.
|
|
||||||
- `vllm:mm_cache_hits` (Counter) - (For multimodal models) Number of multimodal cache hits.
|
|
||||||
- `vllm:num_preemptions_total` (Counter) - Number of preemptions.
|
|
||||||
- `vllm:prompt_tokens_total` (Counter) - Total number of prompt tokens processed.
|
- `vllm:prompt_tokens_total` (Counter) - Total number of prompt tokens processed.
|
||||||
- `vllm:generation_tokens_total` (Counter) - Total number of generated tokens.
|
- `vllm:generation_tokens_total` (Counter) - Total number of generated tokens.
|
||||||
- `vllm:iteration_tokens_total` (Histogram) - Histogram of tokens processed in each engine step.
|
|
||||||
- `vllm:cache_config_info` (Gauge) - Information about the cache configuration.
|
|
||||||
- `vllm:request_success_total` (Counter) - Number of finished requests (by finish reason).
|
- `vllm:request_success_total` (Counter) - Number of finished requests (by finish reason).
|
||||||
- `vllm:request_prompt_tokens` (Histogram) - Histogram of input prompt token counts.
|
- `vllm:request_prompt_tokens` (Histogram) - Histogram of input prompt token counts.
|
||||||
- `vllm:request_generation_tokens` (Histogram) - Histogram of generation token counts.
|
- `vllm:request_generation_tokens` (Histogram) - Histogram of generation token counts.
|
||||||
- `vllm:request_params_n` (Histogram) - Histogram of request parameter n.
|
|
||||||
- `vllm:request_params_max_tokens` - (Histogram) - Histogram of max_tokens parameter in requests.
|
|
||||||
- `vllm:time_to_first_token_seconds` (Histogram) - Time to first token (TTFT).
|
- `vllm:time_to_first_token_seconds` (Histogram) - Time to first token (TTFT).
|
||||||
- `vllm:inter_token_latency_seconds` (Histogram) - Inter-token latency.
|
- `vllm:inter_token_latency_seconds` (Histogram) - Inter-token latency.
|
||||||
- `vllm:e2e_request_latency_seconds` (Histogram) - End-to-end request latency.
|
- `vllm:e2e_request_latency_seconds` (Histogram) - End-to-end request latency.
|
||||||
- `vllm:request_queue_time_seconds` (Histogram) - Time spent in the queue.
|
|
||||||
- `vllm:request_inference_time_seconds` (Histogram) - Request inference time.
|
|
||||||
- `vllm:request_prefill_time_seconds` (Histogram) - Request prefill time.
|
- `vllm:request_prefill_time_seconds` (Histogram) - Request prefill time.
|
||||||
- `vllm:request_decode_time_seconds` (Histogram) - Request decode time.
|
- `vllm:request_decode_time_seconds` (Histogram) - Request decode time.
|
||||||
|
|
||||||
|
|||||||
@ -152,5 +152,5 @@ The interface for the model/module may change during vLLM's development. If you
|
|||||||
## Deprecation announcement
|
## Deprecation announcement
|
||||||
|
|
||||||
!!! warning "Deprecations"
|
!!! warning "Deprecations"
|
||||||
- `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It will be removed in v0.13.0 or v1.0.0.
|
- `use_v1` parameter in `Platform.get_attn_backend_cls` is deprecated. It has been removed in v0.13.0.
|
||||||
- `_Backend` in `vllm.attention` is deprecated. It will be removed in v0.13.0 or v1.0.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead.
|
- `_Backend` in `vllm.attention` is deprecated. It has been removed in v0.13.0. Please use `vllm.attention.backends.registry.register_backend` to add new attention backend to `AttentionBackendEnum` instead.
|
||||||
|
|||||||
@ -22,7 +22,7 @@ python tools/install_nixl_from_source_ubuntu.py
|
|||||||
NixlConnector uses NIXL library for underlying communication, which supports multiple transport backends. UCX (Unified Communication X) is the primary default transport library used by NIXL. Configure transport environment variables:
|
NixlConnector uses NIXL library for underlying communication, which supports multiple transport backends. UCX (Unified Communication X) is the primary default transport library used by NIXL. Configure transport environment variables:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Example UCX configuration, adjust according to your enviroment
|
# Example UCX configuration, adjust according to your environment
|
||||||
export UCX_TLS=all # or specify specific transports like "rc,ud,sm,^cuda_ipc" ..etc
|
export UCX_TLS=all # or specify specific transports like "rc,ud,sm,^cuda_ipc" ..etc
|
||||||
export UCX_NET_DEVICES=all # or specify network devices like "mlx5_0:1,mlx5_1:1"
|
export UCX_NET_DEVICES=all # or specify network devices like "mlx5_0:1,mlx5_1:1"
|
||||||
```
|
```
|
||||||
|
|||||||
@ -61,7 +61,7 @@ Now let´s see an example for each of the cases, starting with the `choice`, as
|
|||||||
print(completion.choices[0].message.content)
|
print(completion.choices[0].message.content)
|
||||||
```
|
```
|
||||||
|
|
||||||
The next example shows how to use the `regex`. The idea is to generate an email address, given a simple regex template:
|
The next example shows how to use the `regex`. The supported regex syntax depends on the structured output backend. For example, `xgrammar`, `guidance`, and `outlines` use Rust-style regex, while `lm-format-enforcer` uses Python's `re` module. The idea is to generate an email address, given a simple regex template:
|
||||||
|
|
||||||
??? code
|
??? code
|
||||||
|
|
||||||
|
|||||||
@ -26,3 +26,4 @@ The backends below live **outside** the main `vllm` repository and follow the
|
|||||||
| Rebellions ATOM / REBEL NPU | `vllm-rbln` | <https://github.com/rebellions-sw/vllm-rbln> |
|
| Rebellions ATOM / REBEL NPU | `vllm-rbln` | <https://github.com/rebellions-sw/vllm-rbln> |
|
||||||
| IBM Spyre AIU | `vllm-spyre` | <https://github.com/vllm-project/vllm-spyre> |
|
| IBM Spyre AIU | `vllm-spyre` | <https://github.com/vllm-project/vllm-spyre> |
|
||||||
| Cambricon MLU | `vllm-mlu` | <https://github.com/Cambricon/vllm-mlu> |
|
| Cambricon MLU | `vllm-mlu` | <https://github.com/Cambricon/vllm-mlu> |
|
||||||
|
| Baidu Kunlun XPU | N/A, install from source | <https://github.com/baidu/vLLM-Kunlun> |
|
||||||
|
|||||||
149
docs/mkdocs/hooks/generate_metrics.py
Normal file
149
docs/mkdocs/hooks/generate_metrics.py
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import ast
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
logger = logging.getLogger("mkdocs")
|
||||||
|
|
||||||
|
ROOT_DIR = Path(__file__).parent.parent.parent.parent
|
||||||
|
DOCS_DIR = ROOT_DIR / "docs"
|
||||||
|
GENERATED_METRICS_DIR = DOCS_DIR / "generated" / "metrics"
|
||||||
|
|
||||||
|
# Files to scan for metric definitions - each will generate a separate table
|
||||||
|
METRIC_SOURCE_FILES = [
|
||||||
|
{"path": "vllm/v1/metrics/loggers.py", "output": "general.md"},
|
||||||
|
{
|
||||||
|
"path": "vllm/v1/spec_decode/metrics.py",
|
||||||
|
"output": "spec_decode.md",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py",
|
||||||
|
"output": "nixl_connector.md",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class MetricExtractor(ast.NodeVisitor):
|
||||||
|
"""AST visitor to extract metric definitions."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.metrics: list[dict[str, str]] = []
|
||||||
|
|
||||||
|
def visit_Call(self, node: ast.Call) -> None:
|
||||||
|
"""Visit function calls to find metric class instantiations."""
|
||||||
|
metric_type = self._get_metric_type(node)
|
||||||
|
if metric_type:
|
||||||
|
name = self._extract_kwarg(node, "name")
|
||||||
|
documentation = self._extract_kwarg(node, "documentation")
|
||||||
|
|
||||||
|
if name:
|
||||||
|
self.metrics.append(
|
||||||
|
{
|
||||||
|
"name": name,
|
||||||
|
"type": metric_type,
|
||||||
|
"documentation": documentation or "",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
def _get_metric_type(self, node: ast.Call) -> str | None:
|
||||||
|
"""Determine if this call creates a metric and return its type."""
|
||||||
|
metric_type_map = {
|
||||||
|
"_gauge_cls": "gauge",
|
||||||
|
"_counter_cls": "counter",
|
||||||
|
"_histogram_cls": "histogram",
|
||||||
|
}
|
||||||
|
if isinstance(node.func, ast.Attribute):
|
||||||
|
return metric_type_map.get(node.func.attr)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _extract_kwarg(self, node: ast.Call, key: str) -> str | None:
|
||||||
|
"""Extract a keyword argument value from a function call."""
|
||||||
|
for keyword in node.keywords:
|
||||||
|
if keyword.arg == key:
|
||||||
|
return self._get_string_value(keyword.value)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_string_value(self, node: ast.AST) -> str | None:
|
||||||
|
"""Extract string value from an AST node."""
|
||||||
|
if isinstance(node, ast.Constant):
|
||||||
|
return str(node.value) if node.value is not None else None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_metrics_from_file(filepath: Path) -> list[dict[str, str]]:
|
||||||
|
"""Parse a Python file and extract all metric definitions."""
|
||||||
|
try:
|
||||||
|
with open(filepath, encoding="utf-8") as f:
|
||||||
|
source = f.read()
|
||||||
|
|
||||||
|
tree = ast.parse(source, filename=str(filepath))
|
||||||
|
extractor = MetricExtractor()
|
||||||
|
extractor.visit(tree)
|
||||||
|
return extractor.metrics
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to parse {filepath}: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def generate_markdown_table(metrics: list[dict[str, str]]) -> str:
|
||||||
|
"""Generate a markdown table from extracted metrics."""
|
||||||
|
if not metrics:
|
||||||
|
return "No metrics found.\n"
|
||||||
|
|
||||||
|
# Sort by type, then by name
|
||||||
|
metrics_sorted = sorted(metrics, key=lambda m: (m["type"], m["name"]))
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
lines.append("| Metric Name | Type | Description |")
|
||||||
|
lines.append("|-------------|------|-------------|")
|
||||||
|
|
||||||
|
for metric in metrics_sorted:
|
||||||
|
name = metric["name"]
|
||||||
|
metric_type = metric["type"].capitalize()
|
||||||
|
doc = metric["documentation"].replace("\n", " ").strip()
|
||||||
|
lines.append(f"| `{name}` | {metric_type} | {doc} |")
|
||||||
|
|
||||||
|
return "\n".join(lines) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
|
||||||
|
"""Generate metrics documentation tables from source files."""
|
||||||
|
logger.info("Generating metrics documentation")
|
||||||
|
|
||||||
|
# Create generated directory if it doesn't exist
|
||||||
|
GENERATED_METRICS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
total_metrics = 0
|
||||||
|
for source_config in METRIC_SOURCE_FILES:
|
||||||
|
source_path = source_config["path"]
|
||||||
|
output_file = source_config["output"]
|
||||||
|
|
||||||
|
filepath = ROOT_DIR / source_path
|
||||||
|
if not filepath.exists():
|
||||||
|
raise FileNotFoundError(f"Metrics source file not found: {filepath}")
|
||||||
|
|
||||||
|
logger.debug("Extracting metrics from: %s", source_path)
|
||||||
|
metrics = extract_metrics_from_file(filepath)
|
||||||
|
logger.debug("Found %d metrics in %s", len(metrics), source_path)
|
||||||
|
|
||||||
|
# Generate and write the markdown table for this source
|
||||||
|
table_content = generate_markdown_table(metrics)
|
||||||
|
output_path = GENERATED_METRICS_DIR / output_file
|
||||||
|
with open(output_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(table_content)
|
||||||
|
|
||||||
|
total_metrics += len(metrics)
|
||||||
|
logger.info(
|
||||||
|
"Generated metrics table: %s (%d metrics)",
|
||||||
|
output_path.relative_to(ROOT_DIR),
|
||||||
|
len(metrics),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Total metrics generated: %d across %d files",
|
||||||
|
total_metrics,
|
||||||
|
len(METRIC_SOURCE_FILES),
|
||||||
|
)
|
||||||
@ -316,10 +316,13 @@ We have split the `encode` task into two more specific token-wise tasks: `token_
|
|||||||
|
|
||||||
### Remove softmax from PoolingParams
|
### Remove softmax from PoolingParams
|
||||||
|
|
||||||
We are going to remove `softmax` and `activation` from `PoolingParams`. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
|
We are going to remove `softmax` and `activation` from `PoolingParams` in v0.15. Instead, use `use_activation`, since we allow `classify` and `token_classify` to use any activation function.
|
||||||
|
|
||||||
### as_reward_model
|
### as_reward_model
|
||||||
|
|
||||||
|
!!! warning
|
||||||
|
We are going to remove `--convert reward` in v0.15, use `--convert embed` instead.
|
||||||
|
|
||||||
Pooling models now default support all pooling, you can use it without any settings.
|
Pooling models now default support all pooling, you can use it without any settings.
|
||||||
|
|
||||||
- Extracting hidden states prefers using `token_embed` task.
|
- Extracting hidden states prefers using `token_embed` task.
|
||||||
|
|||||||
@ -24,7 +24,7 @@ There are two distinct modes supported for online deployments - self-contained w
|
|||||||
|
|
||||||
vLLM supports "self-contained" data parallel deployments that expose a single API endpoint.
|
vLLM supports "self-contained" data parallel deployments that expose a single API endpoint.
|
||||||
|
|
||||||
It can be configured by simply including e.g. `--data-parallel-size=4` in the vllm serve command line arguments. This will require 4 GPUs. It can be combined with tensor parallel, for example `--data-parallel-size=4 --tensor-parallel-size=2`, which would require 8 GPUs.
|
It can be configured by simply including e.g. `--data-parallel-size=4` in the vllm serve command line arguments. This will require 4 GPUs. It can be combined with tensor parallel, for example `--data-parallel-size=4 --tensor-parallel-size=2`, which would require 8 GPUs. When sizing DP deployments, remember that `--max-num-seqs` applies per DP rank.
|
||||||
|
|
||||||
Running a single data parallel deployment across multiple nodes requires a different `vllm serve` to be run on each node, specifying which DP ranks should run on that node. In this case, there will still be a single HTTP entrypoint - the API server(s) will run only on one node, but it doesn't necessarily need to be co-located with the DP ranks.
|
Running a single data parallel deployment across multiple nodes requires a different `vllm serve` to be run on each node, specifying which DP ranks should run on that node. In this case, there will still be a single HTTP entrypoint - the API server(s) will run only on one node, but it doesn't necessarily need to be co-located with the DP ranks.
|
||||||
|
|
||||||
@ -80,6 +80,18 @@ When deploying large DP sizes using this method, the API server process can beco
|
|||||||

|

|
||||||
</figure>
|
</figure>
|
||||||
|
|
||||||
|
## Hybrid Load Balancing
|
||||||
|
|
||||||
|
Hybrid load balancing sits between the internal and external approaches. Each node runs its own API server(s) that only queue requests to the data-parallel engines colocated on that node. An upstream load balancer (for example, an ingress controller or traffic router) spreads user requests across those per-node endpoints.
|
||||||
|
|
||||||
|
Enable this mode with `--data-parallel-hybrid-lb` while still launching every node with the global data-parallel size. The key differences from internal load balancing are:
|
||||||
|
|
||||||
|
- You must provide `--data-parallel-size-local` and `--data-parallel-start-rank` so each node knows which ranks it owns.
|
||||||
|
- Not compatible with `--headless` since every node exposes an API endpoint.
|
||||||
|
- Scale `--api-server-count` per node based on the number of local ranks
|
||||||
|
|
||||||
|
In this configuration, each node keeps scheduling decisions local, which reduces cross-node traffic and avoids single node bottlenecks at larger DP sizes.
|
||||||
|
|
||||||
## External Load Balancing
|
## External Load Balancing
|
||||||
|
|
||||||
For larger scale deployments especially, it can make sense to handle the orchestration and load balancing of data parallel ranks externally.
|
For larger scale deployments especially, it can make sense to handle the orchestration and load balancing of data parallel ranks externally.
|
||||||
|
|||||||
@ -40,10 +40,12 @@ EP_SIZE = TP_SIZE × DP_SIZE
|
|||||||
|
|
||||||
Where:
|
Where:
|
||||||
|
|
||||||
- `TP_SIZE`: Tensor parallel size (always 1 for now)
|
- `TP_SIZE`: Tensor parallel size
|
||||||
- `DP_SIZE`: Data parallel size
|
- `DP_SIZE`: Data parallel size
|
||||||
- `EP_SIZE`: Expert parallel size (computed automatically)
|
- `EP_SIZE`: Expert parallel size (computed automatically)
|
||||||
|
|
||||||
|
When EP is enabled, MoE layers use expert parallelism instead of tensor parallelism, while attention layers continue to use tensor parallelism if `TP_SIZE > 1`.
|
||||||
|
|
||||||
### Example Command
|
### Example Command
|
||||||
|
|
||||||
The following command serves a `DeepSeek-V3-0324` model with 1-way tensor parallel, 8-way (attention) data parallel, and 8-way expert parallel. The attention weights are replicated across all GPUs, while the expert weights are split across GPUs. It will work on a H200 (or H20) node with 8 GPUs. For H100, you can try to serve a smaller model or refer to the multi-node deployment section.
|
The following command serves a `DeepSeek-V3-0324` model with 1-way tensor parallel, 8-way (attention) data parallel, and 8-way expert parallel. The attention weights are replicated across all GPUs, while the expert weights are split across GPUs. It will work on a H200 (or H20) node with 8 GPUs. For H100, you can try to serve a smaller model or refer to the multi-node deployment section.
|
||||||
@ -81,7 +83,7 @@ vllm serve deepseek-ai/DeepSeek-V3-0324 \
|
|||||||
--data-parallel-size-local 8 \ # Local DP size on this node (8 GPUs per node)
|
--data-parallel-size-local 8 \ # Local DP size on this node (8 GPUs per node)
|
||||||
--data-parallel-address 192.168.1.100 \ # Replace with actual IP of Node 1
|
--data-parallel-address 192.168.1.100 \ # Replace with actual IP of Node 1
|
||||||
--data-parallel-rpc-port 13345 \ # RPC communication port, can be any port as long as reachable by all nodes
|
--data-parallel-rpc-port 13345 \ # RPC communication port, can be any port as long as reachable by all nodes
|
||||||
--api-server-count=8 # Number of API servers for load handling (scaling this out to total ranks are recommended)
|
--api-server-count=8 # Number of API servers for load handling (scaling this out to # local ranks is recommended)
|
||||||
|
|
||||||
# Node 2 (Secondary - headless mode, no API server)
|
# Node 2 (Secondary - headless mode, no API server)
|
||||||
vllm serve deepseek-ai/DeepSeek-V3-0324 \
|
vllm serve deepseek-ai/DeepSeek-V3-0324 \
|
||||||
@ -119,9 +121,6 @@ While MoE models are typically trained so that each expert receives a similar nu
|
|||||||
|
|
||||||
Enable EPLB with the `--enable-eplb` flag.
|
Enable EPLB with the `--enable-eplb` flag.
|
||||||
|
|
||||||
!!! note "Model Support"
|
|
||||||
Currently only DeepSeek V3 architecture is supported.
|
|
||||||
|
|
||||||
When enabled, vLLM collects load statistics with every forward pass and periodically rebalances expert distribution.
|
When enabled, vLLM collects load statistics with every forward pass and periodically rebalances expert distribution.
|
||||||
|
|
||||||
### EPLB Parameters
|
### EPLB Parameters
|
||||||
@ -134,6 +133,8 @@ Configure EPLB with the `--eplb-config` argument, which accepts a JSON string. T
|
|||||||
| `step_interval`| Frequency of rebalancing (every N engine steps) | 3000 |
|
| `step_interval`| Frequency of rebalancing (every N engine steps) | 3000 |
|
||||||
| `log_balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` |
|
| `log_balancedness` | Log balancedness metrics (avg tokens per expert ÷ max tokens per expert) | `false` |
|
||||||
| `num_redundant_experts` | Additional global experts per EP rank beyond equal distribution | `0` |
|
| `num_redundant_experts` | Additional global experts per EP rank beyond equal distribution | `0` |
|
||||||
|
| `use_async` | Use non-blocking EPLB for reduced latency overhead | `false` |
|
||||||
|
| `policy` | The policy type for expert parallel load balancing | `"default"` |
|
||||||
|
|
||||||
For example:
|
For example:
|
||||||
|
|
||||||
@ -183,6 +184,26 @@ vllm serve deepseek-ai/DeepSeek-V3-0324 \
|
|||||||
|
|
||||||
For multi-node deployment, add these EPLB flags to each node's command. We recommend setting `--eplb-config '{"num_redundant_experts":32}'` to 32 in large scale use cases so the most popular experts are always available.
|
For multi-node deployment, add these EPLB flags to each node's command. We recommend setting `--eplb-config '{"num_redundant_experts":32}'` to 32 in large scale use cases so the most popular experts are always available.
|
||||||
|
|
||||||
|
## Advanced Configuration
|
||||||
|
|
||||||
|
### Performance Optimization
|
||||||
|
|
||||||
|
- **DeepEP kernels**: The `high_throughput` and `low_latency` kernels are optimized for disaggregated serving and may show poor performance for mixed workloads
|
||||||
|
- **Dual Batch Overlap**: Use `--enable-dbo` to overlap all-to-all communication with compute. See [Dual Batch Overlap](../design/dbo.md) for more details.
|
||||||
|
- **Async scheduling (experimental)**: Try `--async-scheduling` to overlap scheduling with model execution.
|
||||||
|
|
||||||
|
### Troubleshooting
|
||||||
|
|
||||||
|
- **`non-zero status: 7 cannot register cq buf`**: When using Infiniband/RoCE, make sure host VM and pods show `ulimit -l` "unlimited".
|
||||||
|
- **`init failed for transport: IBGDA`**: The InfiniBand GDA kernel modules are missing. Run `tools/ep_kernels/configure_system_drivers.sh` on each GPU node and reboot. Also fixes error `NVSHMEM API called before NVSHMEM initialization has completed`.
|
||||||
|
- **NVSHMEM peer disconnect**: Usually a networking misconfiguration. If deploying via Kubernetes, verify that every pod runs with `hostNetwork: true`, `securityContext.privileged: true` to access Infiniband.
|
||||||
|
|
||||||
|
### Benchmarking
|
||||||
|
|
||||||
|
- Use simulator flags `VLLM_MOE_ROUTING_SIMULATION_STRATEGY=uniform_random` and `VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1` so token routing is balanced across EP ranks.
|
||||||
|
|
||||||
|
- Increasing `VLLM_MOE_DP_CHUNK_SIZE` may increase throughput by increasing the maximum batch size for inter-rank token transfers. This may cause DeepEP to throw `assert self.nvshmem_qp_depth >= (num_max_dispatch_tokens_per_rank + 1) * 2`, which can be fixed by increasing environment variable `NVSHMEM_QP_DEPTH`.
|
||||||
|
|
||||||
## Disaggregated Serving (Prefill/Decode Split)
|
## Disaggregated Serving (Prefill/Decode Split)
|
||||||
|
|
||||||
For production deployments requiring strict SLA guarantees for time-to-first-token and inter-token latency, disaggregated serving allows independent scaling of prefill and decode operations.
|
For production deployments requiring strict SLA guarantees for time-to-first-token and inter-token latency, disaggregated serving allows independent scaling of prefill and decode operations.
|
||||||
@ -273,3 +294,9 @@ except Exception as e:
|
|||||||
print(f"❌ Error during disaggregated serving: {e}")
|
print(f"❌ Error during disaggregated serving: {e}")
|
||||||
print("Check that both prefill and decode instances are running and accessible")
|
print("Check that both prefill and decode instances are running and accessible")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Benchmarking
|
||||||
|
|
||||||
|
- To simulate the decode deployment of disaggregated serving, pass `--kv-transfer-config '{"kv_connector":"DecodeBenchConnector","kv_role":"kv_both"}'` to the `vllm serve` invocation. The connector populates KV cache with random values so decode can be profiled in isolation.
|
||||||
|
|
||||||
|
- **CUDAGraph capture**: Use `--compilation_config '{"cudagraph_mode": "FULL_DECODE_ONLY"}'` to enable CUDA graph capture for decode only and save KV cache.
|
||||||
|
|||||||
@ -33,11 +33,19 @@ Then query the endpoint to get the latest metrics from the server:
|
|||||||
|
|
||||||
The following metrics are exposed:
|
The following metrics are exposed:
|
||||||
|
|
||||||
??? code
|
## General Metrics
|
||||||
|
|
||||||
```python
|
--8<-- "docs/generated/metrics/general.md"
|
||||||
--8<-- "vllm/engine/metrics.py:metrics-definitions"
|
|
||||||
```
|
## Speculative Decoding Metrics
|
||||||
|
|
||||||
|
--8<-- "docs/generated/metrics/spec_decode.md"
|
||||||
|
|
||||||
|
## NIXL KV Connector Metrics
|
||||||
|
|
||||||
|
--8<-- "docs/generated/metrics/nixl_connector.md"
|
||||||
|
|
||||||
|
## Deprecation Policy
|
||||||
|
|
||||||
Note: when metrics are deprecated in version `X.Y`, they are hidden in version `X.Y+1`
|
Note: when metrics are deprecated in version `X.Y`, they are hidden in version `X.Y+1`
|
||||||
but can be re-enabled using the `--show-hidden-metrics-for-version=X.Y` escape hatch,
|
but can be re-enabled using the `--show-hidden-metrics-for-version=X.Y` escape hatch,
|
||||||
|
|||||||
@ -422,7 +422,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=0,
|
||||||
help="Set the seed when initializing `vllm.LLM`.",
|
help="Set the seed when initializing `vllm.LLM`.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@ -77,7 +77,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=0,
|
||||||
help="Set the seed when initializing `vllm.LLM`.",
|
help="Set the seed when initializing `vllm.LLM`.",
|
||||||
)
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|||||||
@ -158,7 +158,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=0,
|
||||||
help="Set the seed when initializing `vllm.LLM`.",
|
help="Set the seed when initializing `vllm.LLM`.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -158,7 +158,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=0,
|
||||||
help="Set the seed when initializing `vllm.LLM`.",
|
help="Set the seed when initializing `vllm.LLM`.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -2031,7 +2031,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=0,
|
||||||
help="Set the seed when initializing `vllm.LLM`.",
|
help="Set the seed when initializing `vllm.LLM`.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1382,7 +1382,7 @@ def run_generate(
|
|||||||
model,
|
model,
|
||||||
question: str,
|
question: str,
|
||||||
image_urls: list[str],
|
image_urls: list[str],
|
||||||
seed: int | None,
|
seed: int,
|
||||||
tensor_parallel_size: int | None,
|
tensor_parallel_size: int | None,
|
||||||
):
|
):
|
||||||
req_data = model_example_map[model](question, image_urls)
|
req_data = model_example_map[model](question, image_urls)
|
||||||
@ -1416,7 +1416,7 @@ def run_chat(
|
|||||||
model: str,
|
model: str,
|
||||||
question: str,
|
question: str,
|
||||||
image_urls: list[str],
|
image_urls: list[str],
|
||||||
seed: int | None,
|
seed: int,
|
||||||
tensor_parallel_size: int | None,
|
tensor_parallel_size: int | None,
|
||||||
):
|
):
|
||||||
req_data = model_example_map[model](question, image_urls)
|
req_data = model_example_map[model](question, image_urls)
|
||||||
@ -1494,7 +1494,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=0,
|
||||||
help="Set the seed when initializing `vllm.LLM`.",
|
help="Set the seed when initializing `vllm.LLM`.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@ -16,7 +16,7 @@ import requests
|
|||||||
# - start vllm in serving mode with the below args
|
# - start vllm in serving mode with the below args
|
||||||
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
|
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
|
||||||
# --model-impl terratorch
|
# --model-impl terratorch
|
||||||
# --task embed --trust-remote-code
|
# --trust-remote-code
|
||||||
# --skip-tokenizer-init --enforce-eager
|
# --skip-tokenizer-init --enforce-eager
|
||||||
# --io-processor-plugin terratorch_segmentation
|
# --io-processor-plugin terratorch_segmentation
|
||||||
# --enable-mm-embeds
|
# --enable-mm-embeds
|
||||||
|
|||||||
@ -305,7 +305,7 @@ def get_query(modality: QueryModality):
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
|
||||||
def run_encode(model: str, modality: QueryModality, seed: int | None):
|
def run_encode(model: str, modality: QueryModality, seed: int):
|
||||||
query = get_query(modality)
|
query = get_query(modality)
|
||||||
req_data = model_example_map[model](query)
|
req_data = model_example_map[model](query)
|
||||||
|
|
||||||
@ -335,7 +335,7 @@ def run_encode(model: str, modality: QueryModality, seed: int | None):
|
|||||||
print("-" * 50)
|
print("-" * 50)
|
||||||
|
|
||||||
|
|
||||||
def run_score(model: str, modality: QueryModality, seed: int | None):
|
def run_score(model: str, modality: QueryModality, seed: int):
|
||||||
query = get_query(modality)
|
query = get_query(modality)
|
||||||
req_data = model_example_map[model](query)
|
req_data = model_example_map[model](query)
|
||||||
|
|
||||||
@ -390,7 +390,7 @@ def parse_args():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=0,
|
||||||
help="Set the seed when initializing `vllm.LLM`.",
|
help="Set the seed when initializing `vllm.LLM`.",
|
||||||
)
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|||||||
@ -51,6 +51,7 @@ hooks:
|
|||||||
- docs/mkdocs/hooks/remove_announcement.py
|
- docs/mkdocs/hooks/remove_announcement.py
|
||||||
- docs/mkdocs/hooks/generate_examples.py
|
- docs/mkdocs/hooks/generate_examples.py
|
||||||
- docs/mkdocs/hooks/generate_argparse.py
|
- docs/mkdocs/hooks/generate_argparse.py
|
||||||
|
- docs/mkdocs/hooks/generate_metrics.py
|
||||||
- docs/mkdocs/hooks/url_schemes.py
|
- docs/mkdocs/hooks/url_schemes.py
|
||||||
|
|
||||||
plugins:
|
plugins:
|
||||||
|
|||||||
@ -1,2 +1,2 @@
|
|||||||
lmcache
|
lmcache >= 0.3.10.post1
|
||||||
nixl >= 0.7.1 # Required for disaggregated prefill
|
nixl >= 0.7.1 # Required for disaggregated prefill
|
||||||
|
|||||||
@ -75,7 +75,7 @@ torchgeo==0.7.0
|
|||||||
mteb==2.1.2
|
mteb==2.1.2
|
||||||
|
|
||||||
# Data processing
|
# Data processing
|
||||||
xgrammar==0.1.27
|
xgrammar @ git+https://github.com/divakar-amd/xgrammar@3272f7c520564858056a60480d5afdf69ae79c84
|
||||||
# Test async scheduling
|
# Test async scheduling
|
||||||
|
|
||||||
# Utilities
|
# Utilities
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import copy
|
import copy
|
||||||
import logging
|
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
@ -13,7 +12,6 @@ from vllm.compilation.fix_functionalization import FixFunctionalizationPass
|
|||||||
from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
|
from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
|
||||||
from vllm.config.compilation import CompilationMode, PassConfig
|
from vllm.config.compilation import CompilationMode, PassConfig
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
from vllm.logger import _print_warning_once
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.torch_utils import _is_torch_equal_or_newer
|
from vllm.utils.torch_utils import _is_torch_equal_or_newer
|
||||||
|
|
||||||
@ -290,7 +288,7 @@ def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
|
|||||||
),
|
),
|
||||||
compilation_config=CompilationConfig(
|
compilation_config=CompilationConfig(
|
||||||
mode=CompilationMode.VLLM_COMPILE,
|
mode=CompilationMode.VLLM_COMPILE,
|
||||||
pass_config={"enable_attn_fusion": True, "enable_noop": True},
|
pass_config={"fuse_attn_quant": True, "eliminate_noops": True},
|
||||||
custom_ops=["+quant_fp8"],
|
custom_ops=["+quant_fp8"],
|
||||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||||
),
|
),
|
||||||
@ -442,62 +440,3 @@ def test_cudagraph_sizes_post_init(
|
|||||||
vllm_config.compilation_config.max_cudagraph_capture_size
|
vllm_config.compilation_config.max_cudagraph_capture_size
|
||||||
== expected_max_size
|
== expected_max_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_pass_config_deprecation(caplog_vllm):
|
|
||||||
caplog_vllm.set_level(logging.WARNING)
|
|
||||||
|
|
||||||
# Clear cache to ensure warnings are re-issued
|
|
||||||
_print_warning_once.cache_clear()
|
|
||||||
|
|
||||||
# Test enable_fusion -> fuse_norm_quant, fuse_act_quant
|
|
||||||
caplog_vllm.clear()
|
|
||||||
config = PassConfig(enable_fusion=True)
|
|
||||||
assert "enable_fusion is deprecated" in caplog_vllm.text
|
|
||||||
assert config.fuse_norm_quant is True
|
|
||||||
assert config.fuse_act_quant is True
|
|
||||||
assert config.enable_fusion is True
|
|
||||||
|
|
||||||
# Test enable_attn_fusion -> fuse_attn_quant
|
|
||||||
caplog_vllm.clear()
|
|
||||||
config = PassConfig(enable_attn_fusion=True)
|
|
||||||
assert "enable_attn_fusion is deprecated" in caplog_vllm.text
|
|
||||||
assert config.fuse_attn_quant is True
|
|
||||||
assert config.enable_attn_fusion is True
|
|
||||||
|
|
||||||
# Test enable_noop -> eliminate_noops
|
|
||||||
caplog_vllm.clear()
|
|
||||||
config = PassConfig(enable_noop=True)
|
|
||||||
assert "enable_noop is deprecated" in caplog_vllm.text
|
|
||||||
assert config.eliminate_noops is True
|
|
||||||
assert config.enable_noop is True
|
|
||||||
|
|
||||||
# Test enable_sequence_parallelism -> enable_sp
|
|
||||||
caplog_vllm.clear()
|
|
||||||
config = PassConfig(enable_sequence_parallelism=True)
|
|
||||||
assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
|
|
||||||
assert config.enable_sp is True
|
|
||||||
assert config.enable_sequence_parallelism is True
|
|
||||||
|
|
||||||
# Test enable_async_tp -> fuse_gemm_comms
|
|
||||||
caplog_vllm.clear()
|
|
||||||
config = PassConfig(enable_async_tp=True)
|
|
||||||
assert "enable_async_tp is deprecated" in caplog_vllm.text
|
|
||||||
assert config.fuse_gemm_comms is True
|
|
||||||
assert config.enable_async_tp is True
|
|
||||||
|
|
||||||
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
|
|
||||||
caplog_vllm.clear()
|
|
||||||
config = PassConfig(enable_fi_allreduce_fusion=True)
|
|
||||||
assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
|
|
||||||
assert config.fuse_allreduce_rms is True
|
|
||||||
assert config.enable_fi_allreduce_fusion is True
|
|
||||||
|
|
||||||
# Test hash consistency
|
|
||||||
config_old = PassConfig(enable_fusion=True)
|
|
||||||
config_new = PassConfig(fuse_norm_quant=True, fuse_act_quant=True)
|
|
||||||
assert config_old.compute_hash() == config_new.compute_hash()
|
|
||||||
|
|
||||||
config_old = PassConfig(enable_async_tp=True)
|
|
||||||
config_new = PassConfig(fuse_gemm_comms=True)
|
|
||||||
assert config_old.compute_hash() == config_new.compute_hash()
|
|
||||||
|
|||||||
@ -1,10 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.plugins
|
import vllm.plugins
|
||||||
|
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
|
||||||
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
|
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
|
||||||
from vllm.compilation.fx_utils import find_op_nodes
|
from vllm.compilation.fx_utils import find_op_nodes
|
||||||
from vllm.compilation.matcher_utils import QUANT_OPS
|
from vllm.compilation.matcher_utils import QUANT_OPS
|
||||||
@ -152,13 +155,79 @@ GROUP_SHAPES = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestRmsnormGroupFp8QuantModel(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, eps: float, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||||
|
weight_group_shape=GroupShape(128, 128),
|
||||||
|
act_quant_group_shape=GroupShape(1, 128),
|
||||||
|
cutlass_block_fp8_supported=False,
|
||||||
|
use_aiter_and_is_supported=True,
|
||||||
|
)
|
||||||
|
self.w = [
|
||||||
|
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||||
|
for _ in range(3)
|
||||||
|
]
|
||||||
|
|
||||||
|
scale_hidden_size = (hidden_size + 128 - 1) // 128
|
||||||
|
self.wscale = [
|
||||||
|
torch.rand((scale_hidden_size, scale_hidden_size), dtype=torch.float32)
|
||||||
|
for _ in range(3)
|
||||||
|
]
|
||||||
|
|
||||||
|
self.norm_weight = [torch.ones(hidden_size) for _ in range(4)]
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# avoid having graph input be an arg to a pattern directly
|
||||||
|
x = resid = torch.relu(x)
|
||||||
|
y = rocm_aiter_ops.rms_norm(x, self.norm_weight[0], self.eps)
|
||||||
|
|
||||||
|
x2 = self.w8a8_block_fp8_linear.apply(y, self.w[0], self.wscale[0])
|
||||||
|
# make sure resid is used for replacement to work
|
||||||
|
y2, resid = rocm_aiter_ops.rms_norm2d_with_add(
|
||||||
|
x2, resid, self.norm_weight[1], self.eps
|
||||||
|
)
|
||||||
|
|
||||||
|
x3 = self.w8a8_block_fp8_linear.apply(y2, self.w[1], self.wscale[1])
|
||||||
|
|
||||||
|
y3, resid = rocm_aiter_ops.rms_norm2d_with_add(
|
||||||
|
x3, resid, self.norm_weight[2], self.eps
|
||||||
|
)
|
||||||
|
|
||||||
|
x4 = self.w8a8_block_fp8_linear.apply(y3, self.w[2], self.wscale[2])
|
||||||
|
|
||||||
|
y4, resid = rocm_aiter_ops.rms_norm2d_with_add(
|
||||||
|
x4, resid, self.norm_weight[3], self.eps
|
||||||
|
)
|
||||||
|
return y4
|
||||||
|
|
||||||
|
def ops_in_model_before(self):
|
||||||
|
return [
|
||||||
|
torch.ops.vllm.rocm_aiter_rms_norm,
|
||||||
|
torch.ops.vllm.rocm_aiter_group_fp8_quant,
|
||||||
|
]
|
||||||
|
|
||||||
|
def ops_in_model_before_partial(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def ops_in_model_after(self):
|
||||||
|
return [
|
||||||
|
torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant,
|
||||||
|
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||||
@pytest.mark.parametrize("hidden_size", [256])
|
@pytest.mark.parametrize("hidden_size", [256])
|
||||||
@pytest.mark.parametrize("num_tokens", [257])
|
@pytest.mark.parametrize("num_tokens", [257])
|
||||||
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
|
||||||
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
|
@pytest.mark.parametrize("group_shape", GROUP_SHAPES)
|
||||||
@pytest.mark.parametrize("enable_rms_norm_custom_op", [True, False])
|
@pytest.mark.parametrize(
|
||||||
@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
|
"model_class, enable_rms_norm_custom_op, enable_quant_fp8_custom_op",
|
||||||
|
list(itertools.product([TestModel], [True, False], [True, False]))
|
||||||
|
+ [(TestRmsnormGroupFp8QuantModel, False, False)],
|
||||||
|
)
|
||||||
# cuda_force_torch used to test torch code path on platforms that
|
# cuda_force_torch used to test torch code path on platforms that
|
||||||
# cutlass_fp8_supported() == True.
|
# cutlass_fp8_supported() == True.
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -173,10 +242,14 @@ def test_fusion_rmsnorm_quant(
|
|||||||
num_tokens,
|
num_tokens,
|
||||||
eps,
|
eps,
|
||||||
group_shape,
|
group_shape,
|
||||||
|
model_class,
|
||||||
enable_rms_norm_custom_op,
|
enable_rms_norm_custom_op,
|
||||||
enable_quant_fp8_custom_op,
|
enable_quant_fp8_custom_op,
|
||||||
cuda_force_torch,
|
cuda_force_torch,
|
||||||
):
|
):
|
||||||
|
if model_class is TestRmsnormGroupFp8QuantModel and not IS_AITER_FOUND:
|
||||||
|
pytest.skip("AITER is not supported on this GPU.")
|
||||||
|
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
torch.manual_seed(1)
|
torch.manual_seed(1)
|
||||||
@ -209,12 +282,24 @@ def test_fusion_rmsnorm_quant(
|
|||||||
with vllm.config.set_current_vllm_config(vllm_config):
|
with vllm.config.set_current_vllm_config(vllm_config):
|
||||||
# Reshape pass is needed for the fusion pass to work
|
# Reshape pass is needed for the fusion pass to work
|
||||||
noop_pass = NoOpEliminationPass(vllm_config)
|
noop_pass = NoOpEliminationPass(vllm_config)
|
||||||
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
if model_class is TestRmsnormGroupFp8QuantModel:
|
||||||
|
from vllm.compilation.rocm_aiter_fusion import (
|
||||||
|
RocmAiterRMSNormFp8GroupQuantFusionPass,
|
||||||
|
)
|
||||||
|
|
||||||
|
fusion_pass = RocmAiterRMSNormFp8GroupQuantFusionPass(vllm_config)
|
||||||
|
else:
|
||||||
|
fusion_pass = RMSNormQuantFusionPass(vllm_config)
|
||||||
cleanup_pass = PostCleanupPass(vllm_config)
|
cleanup_pass = PostCleanupPass(vllm_config)
|
||||||
|
|
||||||
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
backend = TestBackend(noop_pass, fusion_pass, cleanup_pass)
|
||||||
backend2 = TestBackend(noop_pass, cleanup_pass)
|
backend2 = TestBackend(noop_pass, cleanup_pass)
|
||||||
model = TestModel(hidden_size, eps, group_shape, cuda_force_torch)
|
model = model_class(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
eps=eps,
|
||||||
|
group_shape=group_shape,
|
||||||
|
cuda_force_torch=cuda_force_torch,
|
||||||
|
)
|
||||||
# First dimension dynamic
|
# First dimension dynamic
|
||||||
x = torch.rand(num_tokens, hidden_size)
|
x = torch.rand(num_tokens, hidden_size)
|
||||||
torch._dynamo.mark_dynamic(x, 0)
|
torch._dynamo.mark_dynamic(x, 0)
|
||||||
@ -243,7 +328,10 @@ def test_fusion_rmsnorm_quant(
|
|||||||
# there's a risk that the fused add doesn't get included in the
|
# there's a risk that the fused add doesn't get included in the
|
||||||
# replacement and only the rms part gets fused with quant.
|
# replacement and only the rms part gets fused with quant.
|
||||||
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
|
# Hence, we check only 2 add nodes are left (final fused rmsnorm add).
|
||||||
if not enable_rms_norm_custom_op:
|
if (
|
||||||
|
not enable_rms_norm_custom_op
|
||||||
|
and model_class is not TestRmsnormGroupFp8QuantModel
|
||||||
|
):
|
||||||
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
|
n_add_nodes = lambda g: sum(1 for _ in find_op_nodes(torch.ops.aten.add, g))
|
||||||
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
|
# 7 = 1 (RMS) + 3x2 (3xRMS_ADD, 2 each)
|
||||||
assert n_add_nodes(backend.graph_pre_pass) == 7
|
assert n_add_nodes(backend.graph_pre_pass) == 7
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import torch
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
|
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
|
||||||
|
from vllm._aiter_ops import IS_AITER_FOUND
|
||||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||||
from vllm.compilation.activation_quant_fusion import (
|
from vllm.compilation.activation_quant_fusion import (
|
||||||
FUSED_OPS,
|
FUSED_OPS,
|
||||||
@ -24,6 +25,7 @@ from vllm.config import (
|
|||||||
set_current_vllm_config,
|
set_current_vllm_config,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape,
|
GroupShape,
|
||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
@ -126,6 +128,39 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
|||||||
return [FUSED_OPS[kNvfp4Quant]]
|
return [FUSED_OPS[kNvfp4Quant]]
|
||||||
|
|
||||||
|
|
||||||
|
class TestSiluMulGroupFp8QuantModel(torch.nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.silu_and_mul = SiluAndMul()
|
||||||
|
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||||
|
weight_group_shape=GroupShape(128, 128),
|
||||||
|
act_quant_group_shape=GroupShape(1, 128),
|
||||||
|
cutlass_block_fp8_supported=False,
|
||||||
|
use_aiter_and_is_supported=True,
|
||||||
|
)
|
||||||
|
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||||
|
|
||||||
|
scale_hidden_size = (hidden_size + 128 - 1) // 128
|
||||||
|
self.wscale = torch.rand(
|
||||||
|
(scale_hidden_size, scale_hidden_size), dtype=torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
|
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = self.silu_and_mul(x)
|
||||||
|
x2 = self.w8a8_block_fp8_linear.apply(y, self.w, self.wscale)
|
||||||
|
return x2
|
||||||
|
|
||||||
|
def ops_in_model_before(self):
|
||||||
|
return [
|
||||||
|
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
|
||||||
|
]
|
||||||
|
|
||||||
|
def ops_in_model_after(self):
|
||||||
|
return [torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("num_tokens", [32, 64])
|
@pytest.mark.parametrize("num_tokens", [32, 64])
|
||||||
@pytest.mark.parametrize("hidden_size", [128, 256])
|
@pytest.mark.parametrize("hidden_size", [128, 256])
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||||
@ -133,7 +168,10 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model_class, enable_quant_fp8_custom_op, cuda_force_torch",
|
"model_class, enable_quant_fp8_custom_op, cuda_force_torch",
|
||||||
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
|
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
|
||||||
+ [(TestSiluMulNvfp4QuantModel, False, False)],
|
+ [
|
||||||
|
(TestSiluMulNvfp4QuantModel, False, False),
|
||||||
|
(TestSiluMulGroupFp8QuantModel, False, False),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
# cuda_force_torch used to test torch code path on platforms that
|
# cuda_force_torch used to test torch code path on platforms that
|
||||||
# cutlass_fp8_supported() == True.
|
# cutlass_fp8_supported() == True.
|
||||||
@ -144,13 +182,19 @@ def test_fusion_silu_and_mul_quant(
|
|||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
|
model_class: type[
|
||||||
|
TestSiluMulFp8QuantModel
|
||||||
|
| TestSiluMulNvfp4QuantModel
|
||||||
|
| TestSiluMulGroupFp8QuantModel
|
||||||
|
],
|
||||||
enable_silu_mul_custom_op: bool,
|
enable_silu_mul_custom_op: bool,
|
||||||
enable_quant_fp8_custom_op: bool,
|
enable_quant_fp8_custom_op: bool,
|
||||||
cuda_force_torch: bool,
|
cuda_force_torch: bool,
|
||||||
):
|
):
|
||||||
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
|
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
|
||||||
pytest.skip("NVFP4 is not supported on this GPU.")
|
pytest.skip("NVFP4 is not supported on this GPU.")
|
||||||
|
if model_class is TestSiluMulGroupFp8QuantModel and not IS_AITER_FOUND:
|
||||||
|
pytest.skip("AITER is not supported on this GPU.")
|
||||||
|
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
torch.set_default_dtype(dtype)
|
torch.set_default_dtype(dtype)
|
||||||
@ -173,9 +217,15 @@ def test_fusion_silu_and_mul_quant(
|
|||||||
)
|
)
|
||||||
|
|
||||||
with set_current_vllm_config(config):
|
with set_current_vllm_config(config):
|
||||||
fusion_pass = ActivationQuantFusionPass(config)
|
fusion_passes = [ActivationQuantFusionPass(config)]
|
||||||
|
if IS_AITER_FOUND:
|
||||||
|
from vllm.compilation.rocm_aiter_fusion import (
|
||||||
|
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||||
|
)
|
||||||
|
|
||||||
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
|
fusion_passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
|
||||||
|
|
||||||
|
passes = [NoOpEliminationPass(config), *fusion_passes, PostCleanupPass(config)]
|
||||||
backend = TestBackend(*passes)
|
backend = TestBackend(*passes)
|
||||||
model = model_class(
|
model = model_class(
|
||||||
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
|
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
|
||||||
@ -194,12 +244,14 @@ def test_fusion_silu_and_mul_quant(
|
|||||||
atol, rtol = 1e-3, 1e-3
|
atol, rtol = 1e-3, 1e-3
|
||||||
elif model_class == TestSiluMulNvfp4QuantModel:
|
elif model_class == TestSiluMulNvfp4QuantModel:
|
||||||
atol, rtol = 1e-1, 1e-1
|
atol, rtol = 1e-1, 1e-1
|
||||||
|
elif model_class == TestSiluMulGroupFp8QuantModel:
|
||||||
|
atol, rtol = 5e-2, 5e-2
|
||||||
|
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
|
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
|
||||||
)
|
)
|
||||||
|
|
||||||
assert fusion_pass.matched_count == 1
|
assert sum([p.matched_count for p in fusion_passes]) == 1
|
||||||
|
|
||||||
# In pre-nodes, quant op should be present and fused kernels should not
|
# In pre-nodes, quant op should be present and fused kernels should not
|
||||||
backend.check_before_ops(model.ops_in_model_before())
|
backend.check_before_ops(model.ops_in_model_before())
|
||||||
|
|||||||
@ -741,7 +741,7 @@ class VllmRunner:
|
|||||||
tokenizer_name: str | None = None,
|
tokenizer_name: str | None = None,
|
||||||
tokenizer_mode: str = "auto",
|
tokenizer_mode: str = "auto",
|
||||||
trust_remote_code: bool = True,
|
trust_remote_code: bool = True,
|
||||||
seed: int | None = 0,
|
seed: int = 0,
|
||||||
max_model_len: int | None = 1024,
|
max_model_len: int | None = 1024,
|
||||||
dtype: str = "auto",
|
dtype: str = "auto",
|
||||||
disable_log_stats: bool = True,
|
disable_log_stats: bool = True,
|
||||||
|
|||||||
@ -350,21 +350,35 @@ def test_human_readable_model_len():
|
|||||||
assert args.max_model_len == 1_000_000
|
assert args.max_model_len == 1_000_000
|
||||||
args = parser.parse_args(["--max-model-len", "10k"])
|
args = parser.parse_args(["--max-model-len", "10k"])
|
||||||
assert args.max_model_len == 10_000
|
assert args.max_model_len == 10_000
|
||||||
|
args = parser.parse_args(["--max-model-len", "2g"])
|
||||||
|
assert args.max_model_len == 2_000_000_000
|
||||||
|
args = parser.parse_args(["--max-model-len", "2t"])
|
||||||
|
assert args.max_model_len == 2_000_000_000_000
|
||||||
|
|
||||||
# Capital
|
# Capital
|
||||||
args = parser.parse_args(["--max-model-len", "3K"])
|
args = parser.parse_args(["--max-model-len", "3K"])
|
||||||
assert args.max_model_len == 1024 * 3
|
assert args.max_model_len == 2**10 * 3
|
||||||
args = parser.parse_args(["--max-model-len", "10M"])
|
args = parser.parse_args(["--max-model-len", "10M"])
|
||||||
assert args.max_model_len == 2**20 * 10
|
assert args.max_model_len == 2**20 * 10
|
||||||
|
args = parser.parse_args(["--max-model-len", "4G"])
|
||||||
|
assert args.max_model_len == 2**30 * 4
|
||||||
|
args = parser.parse_args(["--max-model-len", "4T"])
|
||||||
|
assert args.max_model_len == 2**40 * 4
|
||||||
|
|
||||||
# Decimal values
|
# Decimal values
|
||||||
args = parser.parse_args(["--max-model-len", "10.2k"])
|
args = parser.parse_args(["--max-model-len", "10.2k"])
|
||||||
assert args.max_model_len == 10200
|
assert args.max_model_len == 10200
|
||||||
# ..truncated to the nearest int
|
# ..truncated to the nearest int
|
||||||
args = parser.parse_args(["--max-model-len", "10.212345k"])
|
args = parser.parse_args(["--max-model-len", "10.2123451234567k"])
|
||||||
assert args.max_model_len == 10212
|
assert args.max_model_len == 10212
|
||||||
|
args = parser.parse_args(["--max-model-len", "10.2123451234567m"])
|
||||||
|
assert args.max_model_len == 10212345
|
||||||
|
args = parser.parse_args(["--max-model-len", "10.2123451234567g"])
|
||||||
|
assert args.max_model_len == 10212345123
|
||||||
|
args = parser.parse_args(["--max-model-len", "10.2123451234567t"])
|
||||||
|
assert args.max_model_len == 10212345123456
|
||||||
|
|
||||||
# Invalid (do not allow decimals with binary multipliers)
|
# Invalid (do not allow decimals with binary multipliers)
|
||||||
for invalid in ["1a", "pwd", "10.24", "1.23M"]:
|
for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
|
||||||
with pytest.raises(ArgumentError):
|
with pytest.raises(ArgumentError):
|
||||||
args = parser.parse_args(["--max-model-len", invalid])
|
parser.parse_args(["--max-model-len", invalid])
|
||||||
|
|||||||
228
tests/entrypoints/openai/test_chat_error.py
Normal file
228
tests/entrypoints/openai/test_chat_error.py
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from http import HTTPStatus
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.config.multimodal import MultiModalConfig
|
||||||
|
from vllm.entrypoints.openai.protocol import ChatCompletionRequest, ErrorResponse
|
||||||
|
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||||
|
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||||
|
from vllm.outputs import CompletionOutput, RequestOutput
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
from vllm.v1.engine.async_llm import AsyncLLM
|
||||||
|
|
||||||
|
MODEL_NAME = "openai-community/gpt2"
|
||||||
|
MODEL_NAME_SHORT = "gpt2"
|
||||||
|
BASE_MODEL_PATHS = [
|
||||||
|
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
|
||||||
|
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockHFConfig:
|
||||||
|
model_type: str = "any"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockModelConfig:
|
||||||
|
task = "generate"
|
||||||
|
runner_type = "generate"
|
||||||
|
tokenizer = MODEL_NAME
|
||||||
|
trust_remote_code = False
|
||||||
|
tokenizer_mode = "auto"
|
||||||
|
max_model_len = 100
|
||||||
|
tokenizer_revision = None
|
||||||
|
multimodal_config = MultiModalConfig()
|
||||||
|
hf_config = MockHFConfig()
|
||||||
|
logits_processor_pattern = None
|
||||||
|
logits_processors: list[str] | None = None
|
||||||
|
diff_sampling_param: dict | None = None
|
||||||
|
allowed_local_media_path: str = ""
|
||||||
|
allowed_media_domains: list[str] | None = None
|
||||||
|
encoder_config = None
|
||||||
|
generation_config: str = "auto"
|
||||||
|
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||||
|
skip_tokenizer_init = False
|
||||||
|
|
||||||
|
def get_diff_sampling_param(self):
|
||||||
|
return self.diff_sampling_param or {}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
|
||||||
|
models = OpenAIServingModels(
|
||||||
|
engine_client=engine,
|
||||||
|
base_model_paths=BASE_MODEL_PATHS,
|
||||||
|
)
|
||||||
|
serving_chat = OpenAIServingChat(
|
||||||
|
engine,
|
||||||
|
models,
|
||||||
|
response_role="assistant",
|
||||||
|
request_logger=None,
|
||||||
|
chat_template=None,
|
||||||
|
chat_template_content_format="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _fake_process_inputs(
|
||||||
|
request_id,
|
||||||
|
engine_prompt,
|
||||||
|
sampling_params,
|
||||||
|
*,
|
||||||
|
lora_request,
|
||||||
|
trace_headers,
|
||||||
|
priority,
|
||||||
|
):
|
||||||
|
return dict(engine_prompt), {}
|
||||||
|
|
||||||
|
async def _fake_preprocess_chat(*args, **kwargs):
|
||||||
|
# return conversation, request_prompts, engine_prompts
|
||||||
|
return (
|
||||||
|
[{"role": "user", "content": "Test"}],
|
||||||
|
[[1, 2, 3]],
|
||||||
|
[{"prompt_token_ids": [1, 2, 3]}],
|
||||||
|
)
|
||||||
|
|
||||||
|
serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
|
||||||
|
serving_chat._preprocess_chat = AsyncMock(side_effect=_fake_preprocess_chat)
|
||||||
|
return serving_chat
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_error_non_stream():
|
||||||
|
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
|
||||||
|
mock_engine = MagicMock(spec=AsyncLLM)
|
||||||
|
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||||
|
mock_engine.errored = False
|
||||||
|
mock_engine.model_config = MockModelConfig()
|
||||||
|
mock_engine.input_processor = MagicMock()
|
||||||
|
mock_engine.io_processor = MagicMock()
|
||||||
|
|
||||||
|
serving_chat = _build_serving_chat(mock_engine)
|
||||||
|
|
||||||
|
completion_output = CompletionOutput(
|
||||||
|
index=0,
|
||||||
|
text="",
|
||||||
|
token_ids=[],
|
||||||
|
cumulative_logprob=None,
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
|
||||||
|
request_output = RequestOutput(
|
||||||
|
request_id="test-id",
|
||||||
|
prompt="Test prompt",
|
||||||
|
prompt_token_ids=[1, 2, 3],
|
||||||
|
prompt_logprobs=None,
|
||||||
|
outputs=[completion_output],
|
||||||
|
finished=True,
|
||||||
|
metrics=None,
|
||||||
|
lora_request=None,
|
||||||
|
encoder_prompt=None,
|
||||||
|
encoder_prompt_token_ids=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_generate(*args, **kwargs):
|
||||||
|
yield request_output
|
||||||
|
|
||||||
|
mock_engine.generate = MagicMock(side_effect=mock_generate)
|
||||||
|
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
messages=[{"role": "user", "content": "Test prompt"}],
|
||||||
|
max_tokens=10,
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await serving_chat.create_chat_completion(request)
|
||||||
|
|
||||||
|
assert isinstance(response, ErrorResponse)
|
||||||
|
assert response.error.type == "InternalServerError"
|
||||||
|
assert response.error.message == "Internal server error"
|
||||||
|
assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_error_stream():
|
||||||
|
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
|
||||||
|
mock_engine = MagicMock(spec=AsyncLLM)
|
||||||
|
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||||
|
mock_engine.errored = False
|
||||||
|
mock_engine.model_config = MockModelConfig()
|
||||||
|
mock_engine.input_processor = MagicMock()
|
||||||
|
mock_engine.io_processor = MagicMock()
|
||||||
|
|
||||||
|
serving_chat = _build_serving_chat(mock_engine)
|
||||||
|
|
||||||
|
completion_output_1 = CompletionOutput(
|
||||||
|
index=0,
|
||||||
|
text="Hello",
|
||||||
|
token_ids=[100],
|
||||||
|
cumulative_logprob=None,
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
request_output_1 = RequestOutput(
|
||||||
|
request_id="test-id",
|
||||||
|
prompt="Test prompt",
|
||||||
|
prompt_token_ids=[1, 2, 3],
|
||||||
|
prompt_logprobs=None,
|
||||||
|
outputs=[completion_output_1],
|
||||||
|
finished=False,
|
||||||
|
metrics=None,
|
||||||
|
lora_request=None,
|
||||||
|
encoder_prompt=None,
|
||||||
|
encoder_prompt_token_ids=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
completion_output_2 = CompletionOutput(
|
||||||
|
index=0,
|
||||||
|
text="Hello",
|
||||||
|
token_ids=[100],
|
||||||
|
cumulative_logprob=None,
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
|
||||||
|
request_output_2 = RequestOutput(
|
||||||
|
request_id="test-id",
|
||||||
|
prompt="Test prompt",
|
||||||
|
prompt_token_ids=[1, 2, 3],
|
||||||
|
prompt_logprobs=None,
|
||||||
|
outputs=[completion_output_2],
|
||||||
|
finished=True,
|
||||||
|
metrics=None,
|
||||||
|
lora_request=None,
|
||||||
|
encoder_prompt=None,
|
||||||
|
encoder_prompt_token_ids=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_generate(*args, **kwargs):
|
||||||
|
yield request_output_1
|
||||||
|
yield request_output_2
|
||||||
|
|
||||||
|
mock_engine.generate = MagicMock(side_effect=mock_generate)
|
||||||
|
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
messages=[{"role": "user", "content": "Test prompt"}],
|
||||||
|
max_tokens=10,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await serving_chat.create_chat_completion(request)
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
async for chunk in response:
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
assert len(chunks) >= 2
|
||||||
|
assert any("Internal server error" in chunk for chunk in chunks), (
|
||||||
|
f"Expected error message in chunks: {chunks}"
|
||||||
|
)
|
||||||
|
assert chunks[-1] == "data: [DONE]\n\n"
|
||||||
216
tests/entrypoints/openai/test_completion_error.py
Normal file
216
tests/entrypoints/openai/test_completion_error.py
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from http import HTTPStatus
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.config.multimodal import MultiModalConfig
|
||||||
|
from vllm.entrypoints.openai.protocol import CompletionRequest, ErrorResponse
|
||||||
|
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||||
|
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
|
||||||
|
from vllm.outputs import CompletionOutput, RequestOutput
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
from vllm.v1.engine.async_llm import AsyncLLM
|
||||||
|
|
||||||
|
MODEL_NAME = "openai-community/gpt2"
|
||||||
|
MODEL_NAME_SHORT = "gpt2"
|
||||||
|
BASE_MODEL_PATHS = [
|
||||||
|
BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
|
||||||
|
BaseModelPath(name=MODEL_NAME_SHORT, model_path=MODEL_NAME_SHORT),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockHFConfig:
|
||||||
|
model_type: str = "any"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockModelConfig:
|
||||||
|
task = "generate"
|
||||||
|
runner_type = "generate"
|
||||||
|
tokenizer = MODEL_NAME
|
||||||
|
trust_remote_code = False
|
||||||
|
tokenizer_mode = "auto"
|
||||||
|
max_model_len = 100
|
||||||
|
tokenizer_revision = None
|
||||||
|
multimodal_config = MultiModalConfig()
|
||||||
|
hf_config = MockHFConfig()
|
||||||
|
logits_processor_pattern = None
|
||||||
|
logits_processors: list[str] | None = None
|
||||||
|
diff_sampling_param: dict | None = None
|
||||||
|
allowed_local_media_path: str = ""
|
||||||
|
allowed_media_domains: list[str] | None = None
|
||||||
|
encoder_config = None
|
||||||
|
generation_config: str = "auto"
|
||||||
|
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||||
|
skip_tokenizer_init = False
|
||||||
|
|
||||||
|
def get_diff_sampling_param(self):
|
||||||
|
return self.diff_sampling_param or {}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
|
||||||
|
models = OpenAIServingModels(
|
||||||
|
engine_client=engine,
|
||||||
|
base_model_paths=BASE_MODEL_PATHS,
|
||||||
|
)
|
||||||
|
serving_completion = OpenAIServingCompletion(
|
||||||
|
engine,
|
||||||
|
models,
|
||||||
|
request_logger=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _fake_process_inputs(
|
||||||
|
request_id,
|
||||||
|
engine_prompt,
|
||||||
|
sampling_params,
|
||||||
|
*,
|
||||||
|
lora_request,
|
||||||
|
trace_headers,
|
||||||
|
priority,
|
||||||
|
):
|
||||||
|
return dict(engine_prompt), {}
|
||||||
|
|
||||||
|
serving_completion._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
|
||||||
|
return serving_completion
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_completion_error_non_stream():
|
||||||
|
"""test finish_reason='error' returns 500 InternalServerError (non-streaming)"""
|
||||||
|
mock_engine = MagicMock(spec=AsyncLLM)
|
||||||
|
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||||
|
mock_engine.errored = False
|
||||||
|
mock_engine.model_config = MockModelConfig()
|
||||||
|
mock_engine.input_processor = MagicMock()
|
||||||
|
mock_engine.io_processor = MagicMock()
|
||||||
|
|
||||||
|
serving_completion = _build_serving_completion(mock_engine)
|
||||||
|
|
||||||
|
completion_output = CompletionOutput(
|
||||||
|
index=0,
|
||||||
|
text="",
|
||||||
|
token_ids=[],
|
||||||
|
cumulative_logprob=None,
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
|
||||||
|
request_output = RequestOutput(
|
||||||
|
request_id="test-id",
|
||||||
|
prompt="Test prompt",
|
||||||
|
prompt_token_ids=[1, 2, 3],
|
||||||
|
prompt_logprobs=None,
|
||||||
|
outputs=[completion_output],
|
||||||
|
finished=True,
|
||||||
|
metrics=None,
|
||||||
|
lora_request=None,
|
||||||
|
encoder_prompt=None,
|
||||||
|
encoder_prompt_token_ids=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_generate(*args, **kwargs):
|
||||||
|
yield request_output
|
||||||
|
|
||||||
|
mock_engine.generate = MagicMock(side_effect=mock_generate)
|
||||||
|
|
||||||
|
request = CompletionRequest(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
prompt="Test prompt",
|
||||||
|
max_tokens=10,
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await serving_completion.create_completion(request)
|
||||||
|
|
||||||
|
assert isinstance(response, ErrorResponse)
|
||||||
|
assert response.error.type == "InternalServerError"
|
||||||
|
assert response.error.message == "Internal server error"
|
||||||
|
assert response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_completion_error_stream():
|
||||||
|
"""test finish_reason='error' returns 500 InternalServerError (streaming)"""
|
||||||
|
mock_engine = MagicMock(spec=AsyncLLM)
|
||||||
|
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||||
|
mock_engine.errored = False
|
||||||
|
mock_engine.model_config = MockModelConfig()
|
||||||
|
mock_engine.input_processor = MagicMock()
|
||||||
|
mock_engine.io_processor = MagicMock()
|
||||||
|
|
||||||
|
serving_completion = _build_serving_completion(mock_engine)
|
||||||
|
|
||||||
|
completion_output_1 = CompletionOutput(
|
||||||
|
index=0,
|
||||||
|
text="Hello",
|
||||||
|
token_ids=[100],
|
||||||
|
cumulative_logprob=None,
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
request_output_1 = RequestOutput(
|
||||||
|
request_id="test-id",
|
||||||
|
prompt="Test prompt",
|
||||||
|
prompt_token_ids=[1, 2, 3],
|
||||||
|
prompt_logprobs=None,
|
||||||
|
outputs=[completion_output_1],
|
||||||
|
finished=False,
|
||||||
|
metrics=None,
|
||||||
|
lora_request=None,
|
||||||
|
encoder_prompt=None,
|
||||||
|
encoder_prompt_token_ids=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
completion_output_2 = CompletionOutput(
|
||||||
|
index=0,
|
||||||
|
text="Hello",
|
||||||
|
token_ids=[100],
|
||||||
|
cumulative_logprob=None,
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
|
||||||
|
request_output_2 = RequestOutput(
|
||||||
|
request_id="test-id",
|
||||||
|
prompt="Test prompt",
|
||||||
|
prompt_token_ids=[1, 2, 3],
|
||||||
|
prompt_logprobs=None,
|
||||||
|
outputs=[completion_output_2],
|
||||||
|
finished=True,
|
||||||
|
metrics=None,
|
||||||
|
lora_request=None,
|
||||||
|
encoder_prompt=None,
|
||||||
|
encoder_prompt_token_ids=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_generate(*args, **kwargs):
|
||||||
|
yield request_output_1
|
||||||
|
yield request_output_2
|
||||||
|
|
||||||
|
mock_engine.generate = MagicMock(side_effect=mock_generate)
|
||||||
|
|
||||||
|
request = CompletionRequest(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
prompt="Test prompt",
|
||||||
|
max_tokens=10,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await serving_completion.create_completion(request)
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
async for chunk in response:
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
assert len(chunks) >= 2
|
||||||
|
assert any("Internal server error" in chunk for chunk in chunks), (
|
||||||
|
f"Expected error message in chunks: {chunks}"
|
||||||
|
)
|
||||||
|
assert chunks[-1] == "data: [DONE]\n\n"
|
||||||
89
tests/entrypoints/openai/test_responses_error.py
Normal file
89
tests/entrypoints/openai/test_responses_error.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from http import HTTPStatus
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.protocol import ErrorResponse
|
||||||
|
from vllm.entrypoints.openai.serving_engine import GenerationError, OpenAIServing
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_raise_if_error_raises_generation_error():
|
||||||
|
"""test _raise_if_error raises GenerationError"""
|
||||||
|
# create a minimal OpenAIServing instance
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_engine.model_config = MagicMock()
|
||||||
|
mock_engine.model_config.max_model_len = 100
|
||||||
|
mock_models = MagicMock()
|
||||||
|
|
||||||
|
serving = OpenAIServing(
|
||||||
|
engine_client=mock_engine,
|
||||||
|
models=mock_models,
|
||||||
|
request_logger=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# test that error finish_reason raises GenerationError
|
||||||
|
with pytest.raises(GenerationError) as exc_info:
|
||||||
|
serving._raise_if_error("error", "test-request-id")
|
||||||
|
|
||||||
|
assert str(exc_info.value) == "Internal server error"
|
||||||
|
assert exc_info.value.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
|
|
||||||
|
# test that other finish_reasons don't raise
|
||||||
|
serving._raise_if_error("stop", "test-request-id") # should not raise
|
||||||
|
serving._raise_if_error("length", "test-request-id") # should not raise
|
||||||
|
serving._raise_if_error(None, "test-request-id") # should not raise
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_convert_generation_error_to_response():
|
||||||
|
"""test _convert_generation_error_to_response creates proper ErrorResponse"""
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_engine.model_config = MagicMock()
|
||||||
|
mock_engine.model_config.max_model_len = 100
|
||||||
|
mock_models = MagicMock()
|
||||||
|
|
||||||
|
serving = OpenAIServing(
|
||||||
|
engine_client=mock_engine,
|
||||||
|
models=mock_models,
|
||||||
|
request_logger=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# create a GenerationError
|
||||||
|
gen_error = GenerationError("Internal server error")
|
||||||
|
|
||||||
|
# convert to ErrorResponse
|
||||||
|
error_response = serving._convert_generation_error_to_response(gen_error)
|
||||||
|
|
||||||
|
assert isinstance(error_response, ErrorResponse)
|
||||||
|
assert error_response.error.type == "InternalServerError"
|
||||||
|
assert error_response.error.message == "Internal server error"
|
||||||
|
assert error_response.error.code == HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_convert_generation_error_to_streaming_response():
|
||||||
|
"""test _convert_generation_error_to_streaming_response output"""
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_engine.model_config = MagicMock()
|
||||||
|
mock_engine.model_config.max_model_len = 100
|
||||||
|
mock_models = MagicMock()
|
||||||
|
|
||||||
|
serving = OpenAIServing(
|
||||||
|
engine_client=mock_engine,
|
||||||
|
models=mock_models,
|
||||||
|
request_logger=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# create a GenerationError
|
||||||
|
gen_error = GenerationError("Internal server error")
|
||||||
|
|
||||||
|
# convert to streaming error response
|
||||||
|
error_json = serving._convert_generation_error_to_streaming_response(gen_error)
|
||||||
|
|
||||||
|
assert isinstance(error_json, str)
|
||||||
|
assert "Internal server error" in error_json
|
||||||
|
assert "InternalServerError" in error_json
|
||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
|
||||||
from openai.types.responses.response_function_tool_call_output_item import (
|
from openai.types.responses.response_function_tool_call_output_item import (
|
||||||
ResponseFunctionToolCallOutputItem,
|
ResponseFunctionToolCallOutputItem,
|
||||||
)
|
)
|
||||||
@ -14,7 +15,8 @@ from openai.types.responses.response_reasoning_item import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from vllm.entrypoints.responses_utils import (
|
from vllm.entrypoints.responses_utils import (
|
||||||
construct_chat_message_with_tool_call,
|
_construct_single_message_from_response_item,
|
||||||
|
construct_chat_messages_with_tool_call,
|
||||||
convert_tool_responses_to_completions_format,
|
convert_tool_responses_to_completions_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -42,7 +44,43 @@ class TestResponsesUtils:
|
|||||||
|
|
||||||
assert result == {"type": "function", "function": input_tool}
|
assert result == {"type": "function", "function": input_tool}
|
||||||
|
|
||||||
def test_construct_chat_message_with_tool_call(self):
|
def test_construct_chat_messages_with_tool_call(self):
|
||||||
|
"""Test construction of chat messages with tool calls."""
|
||||||
|
reasoning_item = ResponseReasoningItem(
|
||||||
|
id="lol",
|
||||||
|
summary=[],
|
||||||
|
type="reasoning",
|
||||||
|
content=[
|
||||||
|
Content(
|
||||||
|
text="Leroy Jenkins",
|
||||||
|
type="reasoning_text",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
encrypted_content=None,
|
||||||
|
status=None,
|
||||||
|
)
|
||||||
|
mcp_tool_item = ResponseFunctionToolCall(
|
||||||
|
id="mcp_123",
|
||||||
|
call_id="call_123",
|
||||||
|
type="function_call",
|
||||||
|
status="completed",
|
||||||
|
name="python",
|
||||||
|
arguments='{"code": "123+456"}',
|
||||||
|
)
|
||||||
|
input_items = [reasoning_item, mcp_tool_item]
|
||||||
|
messages = construct_chat_messages_with_tool_call(input_items)
|
||||||
|
|
||||||
|
assert len(messages) == 1
|
||||||
|
message = messages[0]
|
||||||
|
assert message["role"] == "assistant"
|
||||||
|
assert message["reasoning"] == "Leroy Jenkins"
|
||||||
|
assert message["tool_calls"][0]["id"] == "call_123"
|
||||||
|
assert message["tool_calls"][0]["function"]["name"] == "python"
|
||||||
|
assert (
|
||||||
|
message["tool_calls"][0]["function"]["arguments"] == '{"code": "123+456"}'
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_construct_single_message_from_response_item(self):
|
||||||
item = ResponseReasoningItem(
|
item = ResponseReasoningItem(
|
||||||
id="lol",
|
id="lol",
|
||||||
summary=[],
|
summary=[],
|
||||||
@ -56,7 +94,7 @@ class TestResponsesUtils:
|
|||||||
encrypted_content=None,
|
encrypted_content=None,
|
||||||
status=None,
|
status=None,
|
||||||
)
|
)
|
||||||
formatted_item = construct_chat_message_with_tool_call(item)
|
formatted_item = _construct_single_message_from_response_item(item)
|
||||||
assert formatted_item["role"] == "assistant"
|
assert formatted_item["role"] == "assistant"
|
||||||
assert formatted_item["reasoning"] == "Leroy Jenkins"
|
assert formatted_item["reasoning"] == "Leroy Jenkins"
|
||||||
|
|
||||||
@ -74,7 +112,7 @@ class TestResponsesUtils:
|
|||||||
status=None,
|
status=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
formatted_item = construct_chat_message_with_tool_call(item)
|
formatted_item = _construct_single_message_from_response_item(item)
|
||||||
assert formatted_item["role"] == "assistant"
|
assert formatted_item["role"] == "assistant"
|
||||||
assert (
|
assert (
|
||||||
formatted_item["reasoning"]
|
formatted_item["reasoning"]
|
||||||
@ -88,7 +126,7 @@ class TestResponsesUtils:
|
|||||||
output="1234",
|
output="1234",
|
||||||
status="completed",
|
status="completed",
|
||||||
)
|
)
|
||||||
formatted_item = construct_chat_message_with_tool_call(tool_call_output)
|
formatted_item = _construct_single_message_from_response_item(tool_call_output)
|
||||||
assert formatted_item["role"] == "tool"
|
assert formatted_item["role"] == "tool"
|
||||||
assert formatted_item["content"] == "1234"
|
assert formatted_item["content"] == "1234"
|
||||||
assert formatted_item["tool_call_id"] == "temp"
|
assert formatted_item["tool_call_id"] == "temp"
|
||||||
@ -102,7 +140,7 @@ class TestResponsesUtils:
|
|||||||
status=None,
|
status=None,
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
construct_chat_message_with_tool_call(item)
|
_construct_single_message_from_response_item(item)
|
||||||
|
|
||||||
output_item = ResponseOutputMessage(
|
output_item = ResponseOutputMessage(
|
||||||
id="msg_bf585bbbe3d500e0",
|
id="msg_bf585bbbe3d500e0",
|
||||||
@ -119,6 +157,6 @@ class TestResponsesUtils:
|
|||||||
type="message",
|
type="message",
|
||||||
)
|
)
|
||||||
|
|
||||||
formatted_item = construct_chat_message_with_tool_call(output_item)
|
formatted_item = _construct_single_message_from_response_item(output_item)
|
||||||
assert formatted_item["role"] == "assistant"
|
assert formatted_item["role"] == "assistant"
|
||||||
assert formatted_item["content"] == "dongyi"
|
assert formatted_item["content"] == "dongyi"
|
||||||
|
|||||||
@ -7,7 +7,8 @@ import math
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import CpuArchEnum, current_platform
|
||||||
|
from vllm.v1.attention.backends.cpu_attn import _get_attn_isa
|
||||||
|
|
||||||
if not current_platform.is_cpu():
|
if not current_platform.is_cpu():
|
||||||
pytest.skip("skipping CPU-only tests", allow_module_level=True)
|
pytest.skip("skipping CPU-only tests", allow_module_level=True)
|
||||||
@ -36,6 +37,21 @@ SEQ_LENS = [ # (q_len, kv_len)
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_attn_isa(
|
||||||
|
block_size: int | None = None,
|
||||||
|
dtype: torch.dtype | None = None,
|
||||||
|
):
|
||||||
|
if block_size and dtype:
|
||||||
|
return _get_attn_isa(dtype, block_size)
|
||||||
|
else:
|
||||||
|
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
|
||||||
|
return "neon"
|
||||||
|
elif torch._C._cpu._is_amx_tile_supported():
|
||||||
|
return "amx"
|
||||||
|
else:
|
||||||
|
return "vec"
|
||||||
|
|
||||||
|
|
||||||
# rand number generation takes too much time, cache rand tensors
|
# rand number generation takes too much time, cache rand tensors
|
||||||
@functools.lru_cache(maxsize=128, typed=False)
|
@functools.lru_cache(maxsize=128, typed=False)
|
||||||
def tensor_cache(
|
def tensor_cache(
|
||||||
@ -452,6 +468,49 @@ def test_varlen_with_paged_kv_normal_vec16(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||||
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
|
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||||
|
@pytest.mark.parametrize("block_size", [96, 128])
|
||||||
|
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS)
|
||||||
|
@pytest.mark.parametrize("dtype", QTYPES)
|
||||||
|
@pytest.mark.parametrize("soft_cap", [None])
|
||||||
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
|
@pytest.mark.parametrize("use_alibi", [False])
|
||||||
|
@pytest.mark.parametrize("use_sink", [False])
|
||||||
|
@pytest.mark.parametrize("isa", ["neon"])
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
current_platform.get_cpu_architecture() != CpuArchEnum.ARM,
|
||||||
|
reason="Not an Arm CPU.",
|
||||||
|
)
|
||||||
|
def test_varlen_with_paged_kv_normal_neon(
|
||||||
|
seq_lens: list[tuple[int, int]],
|
||||||
|
num_heads: tuple[int, int],
|
||||||
|
head_size: int,
|
||||||
|
sliding_window: int | None,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
block_size: int,
|
||||||
|
soft_cap: float | None,
|
||||||
|
num_blocks: int,
|
||||||
|
use_alibi: bool,
|
||||||
|
use_sink: bool,
|
||||||
|
isa: str,
|
||||||
|
) -> None:
|
||||||
|
varlen_with_paged_kv(
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
num_heads=num_heads,
|
||||||
|
head_size=head_size,
|
||||||
|
sliding_window=sliding_window,
|
||||||
|
dtype=dtype,
|
||||||
|
block_size=block_size,
|
||||||
|
soft_cap=soft_cap,
|
||||||
|
num_blocks=num_blocks,
|
||||||
|
use_alibi=use_alibi,
|
||||||
|
use_sink=use_sink,
|
||||||
|
isa=isa,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
@pytest.mark.parametrize("seq_lens", SEQ_LENS)
|
||||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||||
@pytest.mark.parametrize("head_size", [96])
|
@pytest.mark.parametrize("head_size", [96])
|
||||||
@ -462,9 +521,7 @@ def test_varlen_with_paged_kv_normal_vec16(
|
|||||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
@pytest.mark.parametrize("use_alibi", [False])
|
@pytest.mark.parametrize("use_alibi", [False])
|
||||||
@pytest.mark.parametrize("use_sink", [False])
|
@pytest.mark.parametrize("use_sink", [False])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("isa", [get_attn_isa()])
|
||||||
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
|
||||||
)
|
|
||||||
def test_varlen_with_paged_kv_softcap(
|
def test_varlen_with_paged_kv_softcap(
|
||||||
seq_lens: list[tuple[int, int]],
|
seq_lens: list[tuple[int, int]],
|
||||||
num_heads: tuple[int, int],
|
num_heads: tuple[int, int],
|
||||||
@ -503,9 +560,7 @@ def test_varlen_with_paged_kv_softcap(
|
|||||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
@pytest.mark.parametrize("use_alibi", [True])
|
@pytest.mark.parametrize("use_alibi", [True])
|
||||||
@pytest.mark.parametrize("use_sink", [False])
|
@pytest.mark.parametrize("use_sink", [False])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("isa", [get_attn_isa()])
|
||||||
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
|
||||||
)
|
|
||||||
def test_varlen_with_paged_kv_alibi(
|
def test_varlen_with_paged_kv_alibi(
|
||||||
seq_lens: list[tuple[int, int]],
|
seq_lens: list[tuple[int, int]],
|
||||||
num_heads: tuple[int, int],
|
num_heads: tuple[int, int],
|
||||||
@ -544,9 +599,7 @@ def test_varlen_with_paged_kv_alibi(
|
|||||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||||
@pytest.mark.parametrize("use_alibi", [False])
|
@pytest.mark.parametrize("use_alibi", [False])
|
||||||
@pytest.mark.parametrize("use_sink", [True])
|
@pytest.mark.parametrize("use_sink", [True])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("isa", [get_attn_isa()])
|
||||||
"isa", ["amx"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
|
||||||
)
|
|
||||||
def test_varlen_with_paged_kv_sink(
|
def test_varlen_with_paged_kv_sink(
|
||||||
seq_lens: list[tuple[int, int]],
|
seq_lens: list[tuple[int, int]],
|
||||||
num_heads: tuple[int, int],
|
num_heads: tuple[int, int],
|
||||||
|
|||||||
@ -70,12 +70,12 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
|
|||||||
f"{torch.cuda.device_count()}"
|
f"{torch.cuda.device_count()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# `cuda_graph_sizes=[16]` to reduce load time.
|
# `cudagraph_capture_sizes=[16]` to reduce load time.
|
||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model_case.model_id,
|
model_case.model_id,
|
||||||
tensor_parallel_size=model_case.tp,
|
tensor_parallel_size=model_case.tp,
|
||||||
load_format="dummy",
|
load_format="dummy",
|
||||||
cuda_graph_sizes=[16],
|
cudagraph_capture_sizes=[16],
|
||||||
) as llm:
|
) as llm:
|
||||||
# Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
|
# Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
|
||||||
# def check_model(model):
|
# def check_model(model):
|
||||||
|
|||||||
@ -54,6 +54,10 @@ def setup_cuda():
|
|||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
current_platform.is_fp8_fnuz(),
|
||||||
|
reason="This platform supports e4m3fnuz, not e4m3fn.",
|
||||||
|
)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"num_tokens,d,dtype,group_size,seed",
|
"num_tokens,d,dtype,group_size,seed",
|
||||||
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
|
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
|
||||||
@ -78,14 +82,14 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
|
|||||||
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
|
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
factor_for_scale = 1e-2
|
factor_for_scale = 1e-2
|
||||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
fp8_info = torch.finfo(current_platform.fp8_dtype())
|
||||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||||
|
|
||||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||||
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
|
||||||
|
|
||||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||||
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(current_platform.fp8_dtype())
|
||||||
|
|
||||||
block_n, block_k = block_size[0], block_size[1]
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
n_tiles = (N + block_n - 1) // block_n
|
n_tiles = (N + block_n - 1) // block_n
|
||||||
@ -103,6 +107,9 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
|
|||||||
assert rel_diff < 0.001
|
assert rel_diff < 0.001
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not current_platform.is_cuda(), reason="CUTLASS only supported on CUDA platform."
|
||||||
|
)
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_w8a8_block_fp8_cutlass_matmul():
|
def test_w8a8_block_fp8_cutlass_matmul():
|
||||||
# Test simple case where weight.shape % 128 != 0,
|
# Test simple case where weight.shape % 128 != 0,
|
||||||
@ -151,6 +158,10 @@ def test_w8a8_block_fp8_cutlass_matmul():
|
|||||||
assert rel_diff < 0.001
|
assert rel_diff < 0.001
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
current_platform.is_fp8_fnuz(),
|
||||||
|
reason="This platform supports e4m3fnuz, not e4m3fn.",
|
||||||
|
)
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"M,N,K,block_size,out_dtype,seed",
|
"M,N,K,block_size,out_dtype,seed",
|
||||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
|
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
|
||||||
|
|||||||
@ -15,6 +15,9 @@ from vllm import _custom_ops as ops
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
|
|
||||||
|
if not current_platform.is_cuda():
|
||||||
|
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
|
||||||
|
|
||||||
MNK_FACTORS = [
|
MNK_FACTORS = [
|
||||||
(1, 256, 128),
|
(1, 256, 128),
|
||||||
(1, 16384, 1024),
|
(1, 16384, 1024),
|
||||||
|
|||||||
@ -21,6 +21,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.scalar_type import ScalarType, scalar_types
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
|
|
||||||
|
if not current_platform.is_cuda():
|
||||||
|
pytest.skip("These tests use CUTLASS which requires CUDA", allow_module_level=True)
|
||||||
|
|
||||||
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
||||||
# unit tests to a common utility function. Currently the use of
|
# unit tests to a common utility function. Currently the use of
|
||||||
# `is_quant_method_supported` conflates kernels with quantization methods
|
# `is_quant_method_supported` conflates kernels with quantization methods
|
||||||
|
|||||||
@ -17,7 +17,6 @@ def test_idefics_multimodal(
|
|||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model_name="HuggingFaceM4/Idefics3-8B-Llama3",
|
model_name="HuggingFaceM4/Idefics3-8B-Llama3",
|
||||||
runner="pooling",
|
runner="pooling",
|
||||||
task="classify",
|
|
||||||
convert="classify",
|
convert="classify",
|
||||||
load_format="dummy",
|
load_format="dummy",
|
||||||
max_model_len=512,
|
max_model_len=512,
|
||||||
@ -86,7 +85,6 @@ def test_gemma_multimodal(
|
|||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model_name="google/gemma-3-4b-it",
|
model_name="google/gemma-3-4b-it",
|
||||||
runner="pooling",
|
runner="pooling",
|
||||||
task="classify",
|
|
||||||
convert="classify",
|
convert="classify",
|
||||||
load_format="auto",
|
load_format="auto",
|
||||||
hf_overrides=update_config,
|
hf_overrides=update_config,
|
||||||
|
|||||||
@ -92,16 +92,19 @@ def run_test(
|
|||||||
*,
|
*,
|
||||||
tensor_parallel_size: int,
|
tensor_parallel_size: int,
|
||||||
distributed_executor_backend: str | None = None,
|
distributed_executor_backend: str | None = None,
|
||||||
|
dtype: str = "half",
|
||||||
) -> None:
|
) -> None:
|
||||||
prompt_list = PROMPTS * 10
|
prompt_list = PROMPTS * 10
|
||||||
expected_list = EXPECTED[model] * 10
|
expected_list = EXPECTED[model] * 10
|
||||||
|
|
||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model,
|
model,
|
||||||
dtype="half",
|
dtype=dtype,
|
||||||
max_model_len=448,
|
max_model_len=448,
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
distributed_executor_backend=distributed_executor_backend,
|
distributed_executor_backend=distributed_executor_backend,
|
||||||
|
# TODO (NickLucche) figure out output differences with non-eager and re-enable
|
||||||
|
enforce_eager=True,
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
llm = vllm_model.llm
|
llm = vllm_model.llm
|
||||||
|
|
||||||
@ -120,12 +123,28 @@ def run_test(
|
|||||||
|
|
||||||
@pytest.mark.core_model
|
@pytest.mark.core_model
|
||||||
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_models(vllm_runner, model) -> None:
|
def test_models(vllm_runner, model, dtype) -> None:
|
||||||
run_test(
|
run_test(
|
||||||
vllm_runner,
|
vllm_runner,
|
||||||
model,
|
model,
|
||||||
tensor_parallel_size=1,
|
tensor_parallel_size=1,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.cpu_model
|
||||||
|
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
||||||
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
def test_models_cpu(vllm_runner, model, dtype) -> None:
|
||||||
|
# @create_new_process_for_each_test() does not work for some runners
|
||||||
|
# TODO: to fix cpu privilege issues in run-cpu-test-arm.sh
|
||||||
|
run_test(
|
||||||
|
vllm_runner,
|
||||||
|
model,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
@ -186,6 +187,7 @@ async def test_fetch_image_error_conversion():
|
|||||||
connector.fetch_image(broken_img)
|
connector.fetch_image(broken_img)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.flaky(reruns=3, reruns_delay=5)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||||
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
|
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
|
||||||
@ -198,8 +200,12 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
video_sync, metadata_sync = connector.fetch_video(video_url)
|
try:
|
||||||
video_async, metadata_async = await connector.fetch_video_async(video_url)
|
video_sync, metadata_sync = connector.fetch_video(video_url)
|
||||||
|
video_async, metadata_async = await connector.fetch_video_async(video_url)
|
||||||
|
except (TimeoutError, asyncio.TimeoutError) as e:
|
||||||
|
pytest.skip(f"Timeout fetching video (CI network flakiness): {e}")
|
||||||
|
|
||||||
assert np.array_equal(video_sync, video_async)
|
assert np.array_equal(video_sync, video_async)
|
||||||
assert metadata_sync == metadata_async
|
assert metadata_sync == metadata_async
|
||||||
|
|
||||||
|
|||||||
@ -147,7 +147,7 @@ def test_video_backend_handles_broken_frames(monkeypatch: pytest.MonkeyPatch):
|
|||||||
"""
|
"""
|
||||||
Regression test for handling videos with broken frames.
|
Regression test for handling videos with broken frames.
|
||||||
This test uses a pre-corrupted video file (assets/corrupted.mp4) that
|
This test uses a pre-corrupted video file (assets/corrupted.mp4) that
|
||||||
contains broken/unreadable frames to verify the video loader handles
|
contains broken frames to verify the video loader handles
|
||||||
them gracefully without crashing and returns accurate metadata.
|
them gracefully without crashing and returns accurate metadata.
|
||||||
"""
|
"""
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
@ -177,3 +177,125 @@ def test_video_backend_handles_broken_frames(monkeypatch: pytest.MonkeyPatch):
|
|||||||
f"Expected fewer than {metadata['total_num_frames']} frames, "
|
f"Expected fewer than {metadata['total_num_frames']} frames, "
|
||||||
f"but loaded {frames.shape[0]} frames"
|
f"but loaded {frames.shape[0]} frames"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@VIDEO_LOADER_REGISTRY.register("test_video_backend_override_1")
|
||||||
|
class TestVideoBackendOverride1(VideoLoader):
|
||||||
|
"""Test loader that returns FAKE_OUTPUT_1 to verify backend selection."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_bytes(
|
||||||
|
cls, data: bytes, num_frames: int = -1, **kwargs
|
||||||
|
) -> tuple[npt.NDArray, dict]:
|
||||||
|
return FAKE_OUTPUT_1, {"video_backend": "test_video_backend_override_1"}
|
||||||
|
|
||||||
|
|
||||||
|
@VIDEO_LOADER_REGISTRY.register("test_video_backend_override_2")
|
||||||
|
class TestVideoBackendOverride2(VideoLoader):
|
||||||
|
"""Test loader that returns FAKE_OUTPUT_2 to verify backend selection."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_bytes(
|
||||||
|
cls, data: bytes, num_frames: int = -1, **kwargs
|
||||||
|
) -> tuple[npt.NDArray, dict]:
|
||||||
|
return FAKE_OUTPUT_2, {"video_backend": "test_video_backend_override_2"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_video_media_io_backend_kwarg_override(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""
|
||||||
|
Test that video_backend kwarg can override the VLLM_VIDEO_LOADER_BACKEND
|
||||||
|
environment variable.
|
||||||
|
|
||||||
|
This allows users to dynamically select a different video backend
|
||||||
|
via --media-io-kwargs without changing the global env var, which is
|
||||||
|
useful when plugins set a default backend but a specific request
|
||||||
|
needs a different one.
|
||||||
|
"""
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
# Set the env var to one backend
|
||||||
|
m.setenv("VLLM_VIDEO_LOADER_BACKEND", "test_video_backend_override_1")
|
||||||
|
|
||||||
|
imageio = ImageMediaIO()
|
||||||
|
|
||||||
|
# Without video_backend kwarg, should use env var backend
|
||||||
|
videoio_default = VideoMediaIO(imageio, num_frames=10)
|
||||||
|
frames_default, metadata_default = videoio_default.load_bytes(b"test")
|
||||||
|
np.testing.assert_array_equal(frames_default, FAKE_OUTPUT_1)
|
||||||
|
assert metadata_default["video_backend"] == "test_video_backend_override_1"
|
||||||
|
|
||||||
|
# With video_backend kwarg, should override env var
|
||||||
|
videoio_override = VideoMediaIO(
|
||||||
|
imageio, num_frames=10, video_backend="test_video_backend_override_2"
|
||||||
|
)
|
||||||
|
frames_override, metadata_override = videoio_override.load_bytes(b"test")
|
||||||
|
np.testing.assert_array_equal(frames_override, FAKE_OUTPUT_2)
|
||||||
|
assert metadata_override["video_backend"] == "test_video_backend_override_2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_video_media_io_backend_kwarg_not_passed_to_loader(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test that video_backend kwarg is consumed by VideoMediaIO and NOT passed
|
||||||
|
through to the underlying video loader's load_bytes method.
|
||||||
|
|
||||||
|
This ensures the kwarg is properly popped from kwargs before forwarding.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@VIDEO_LOADER_REGISTRY.register("test_reject_video_backend_kwarg")
|
||||||
|
class RejectVideoBackendKwargLoader(VideoLoader):
|
||||||
|
"""Test loader that fails if video_backend is passed through."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_bytes(
|
||||||
|
cls, data: bytes, num_frames: int = -1, **kwargs
|
||||||
|
) -> tuple[npt.NDArray, dict]:
|
||||||
|
# This should never receive video_backend in kwargs
|
||||||
|
if "video_backend" in kwargs:
|
||||||
|
raise AssertionError(
|
||||||
|
"video_backend should be consumed by VideoMediaIO, "
|
||||||
|
"not passed to loader"
|
||||||
|
)
|
||||||
|
return FAKE_OUTPUT_1, {"received_kwargs": list(kwargs.keys())}
|
||||||
|
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setenv("VLLM_VIDEO_LOADER_BACKEND", "test_reject_video_backend_kwarg")
|
||||||
|
|
||||||
|
imageio = ImageMediaIO()
|
||||||
|
|
||||||
|
# Even when video_backend is provided, it should NOT be passed to loader
|
||||||
|
videoio = VideoMediaIO(
|
||||||
|
imageio,
|
||||||
|
num_frames=10,
|
||||||
|
video_backend="test_reject_video_backend_kwarg",
|
||||||
|
other_kwarg="should_pass_through",
|
||||||
|
)
|
||||||
|
|
||||||
|
# This should NOT raise AssertionError
|
||||||
|
frames, metadata = videoio.load_bytes(b"test")
|
||||||
|
np.testing.assert_array_equal(frames, FAKE_OUTPUT_1)
|
||||||
|
# Verify other kwargs are still passed through
|
||||||
|
assert "other_kwarg" in metadata["received_kwargs"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_video_media_io_backend_env_var_fallback(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
"""
|
||||||
|
Test that when video_backend kwarg is None or not provided,
|
||||||
|
VideoMediaIO falls back to VLLM_VIDEO_LOADER_BACKEND env var.
|
||||||
|
"""
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
m.setenv("VLLM_VIDEO_LOADER_BACKEND", "test_video_backend_override_2")
|
||||||
|
|
||||||
|
imageio = ImageMediaIO()
|
||||||
|
|
||||||
|
# Explicit None should fall back to env var
|
||||||
|
videoio_none = VideoMediaIO(imageio, num_frames=10, video_backend=None)
|
||||||
|
frames_none, metadata_none = videoio_none.load_bytes(b"test")
|
||||||
|
np.testing.assert_array_equal(frames_none, FAKE_OUTPUT_2)
|
||||||
|
assert metadata_none["video_backend"] == "test_video_backend_override_2"
|
||||||
|
|
||||||
|
# Not providing video_backend should also fall back to env var
|
||||||
|
videoio_missing = VideoMediaIO(imageio, num_frames=10)
|
||||||
|
frames_missing, metadata_missing = videoio_missing.load_bytes(b"test")
|
||||||
|
np.testing.assert_array_equal(frames_missing, FAKE_OUTPUT_2)
|
||||||
|
assert metadata_missing["video_backend"] == "test_video_backend_override_2"
|
||||||
|
|||||||
@ -10,10 +10,14 @@ import torch
|
|||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.quantization.fp8 import (
|
from vllm.model_executor.layers.quantization.fp8 import (
|
||||||
|
Fp8Config,
|
||||||
Fp8KVCacheMethod,
|
Fp8KVCacheMethod,
|
||||||
Fp8LinearMethod,
|
Fp8LinearMethod,
|
||||||
|
Fp8MoEMethod,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
@ -261,3 +265,87 @@ def test_scaled_fp8_quant(dtype) -> None:
|
|||||||
torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), inv_scale_nc, dtype
|
torch.narrow(y_nc_pad, 0, 0, x_nc.shape[0]), inv_scale_nc, dtype
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("method_cls", [Fp8LinearMethod, Fp8MoEMethod])
|
||||||
|
# FP8 weight reloading does not support online quantization
|
||||||
|
@pytest.mark.parametrize("is_checkpoint_fp8_serialized", [True]) # skip False
|
||||||
|
@pytest.mark.parametrize("weight_block_size", [None, [1, 1]])
|
||||||
|
# any postprocessing that is applied to the weights such as padding and repacking
|
||||||
|
# (excluding device sharding) must also be applied to the reloaded weights
|
||||||
|
#
|
||||||
|
# this is the case for marlin as well as per-tensor Fp8MoEMethod
|
||||||
|
@pytest.mark.parametrize("use_marlin", [False]) # skip True
|
||||||
|
def test_fp8_reloading(
|
||||||
|
method_cls, is_checkpoint_fp8_serialized, weight_block_size, use_marlin, dist_init
|
||||||
|
):
|
||||||
|
if is_checkpoint_fp8_serialized is False:
|
||||||
|
pytest.skip("FP8 weight reloading does not support online quantization")
|
||||||
|
|
||||||
|
if method_cls is Fp8MoEMethod and weight_block_size is None:
|
||||||
|
pytest.skip(
|
||||||
|
"FP8 Tensor weight reloading does not support fusing w13_weight_scale. "
|
||||||
|
"If this is your use case, consider using a restore function like #26327"
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.device("cuda:0"):
|
||||||
|
config = Fp8Config(
|
||||||
|
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
||||||
|
weight_block_size=weight_block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
if method_cls is Fp8LinearMethod:
|
||||||
|
layer = torch.nn.Linear(1, 1)
|
||||||
|
method = method_cls(config)
|
||||||
|
method.create_weights(
|
||||||
|
layer=layer,
|
||||||
|
input_size_per_partition=1,
|
||||||
|
output_partition_sizes=[1],
|
||||||
|
input_size=1,
|
||||||
|
output_size=1,
|
||||||
|
params_dtype=torch.bfloat16,
|
||||||
|
weight_loader=default_weight_loader,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
layer = FusedMoE(
|
||||||
|
num_experts=1,
|
||||||
|
top_k=1,
|
||||||
|
hidden_size=1,
|
||||||
|
intermediate_size=1,
|
||||||
|
)
|
||||||
|
method = method_cls(config, layer)
|
||||||
|
method.create_weights(
|
||||||
|
layer=layer,
|
||||||
|
num_experts=1,
|
||||||
|
hidden_size=1,
|
||||||
|
intermediate_size_per_partition=1,
|
||||||
|
params_dtype=torch.bfloat16,
|
||||||
|
weight_loader=default_weight_loader,
|
||||||
|
)
|
||||||
|
|
||||||
|
method.use_marlin = use_marlin
|
||||||
|
|
||||||
|
# capture weights format during loading
|
||||||
|
original_metadata = [
|
||||||
|
(name, param.shape, getattr(param, "weight_loader", default_weight_loader))
|
||||||
|
for name, param in layer.named_parameters()
|
||||||
|
]
|
||||||
|
|
||||||
|
# test loading
|
||||||
|
for name, shape, _ in original_metadata:
|
||||||
|
param = getattr(layer, name)
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
weight_loader(param, torch.zeros(shape)) # cannot use empty
|
||||||
|
|
||||||
|
method.process_weights_after_loading(layer)
|
||||||
|
|
||||||
|
# test reloading works after loading
|
||||||
|
# assuming that no reshaping occurred
|
||||||
|
for name, shape, original_weight_loader in original_metadata:
|
||||||
|
param = getattr(layer, name)
|
||||||
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||||
|
assert weight_loader is original_weight_loader
|
||||||
|
weight_loader(param, torch.zeros(shape)) # cannot use empty
|
||||||
|
|
||||||
|
method.process_weights_after_loading(layer)
|
||||||
|
|||||||
@ -212,11 +212,11 @@ def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
|
|||||||
task = "wikitext"
|
task = "wikitext"
|
||||||
rtol = 0.1
|
rtol = 0.1
|
||||||
|
|
||||||
# Smaller cuda_graph_sizes to speed up the test.
|
# Smaller cudagraph_capture_sizes to speed up the test.
|
||||||
results = lm_eval.simple_evaluate(
|
results = lm_eval.simple_evaluate(
|
||||||
model="vllm",
|
model="vllm",
|
||||||
model_args=config.get_model_args(
|
model_args=config.get_model_args(
|
||||||
tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]}
|
tp_size=tp_size, kwargs={"cudagraph_capture_sizes": [16]}
|
||||||
),
|
),
|
||||||
tasks=task,
|
tasks=task,
|
||||||
batch_size=64,
|
batch_size=64,
|
||||||
|
|||||||
195
tests/reasoning/test_minimax_m2_append_reasoning_parser.py
Normal file
195
tests/reasoning/test_minimax_m2_append_reasoning_parser.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from tests.reasoning.utils import run_reasoning_extraction
|
||||||
|
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||||
|
|
||||||
|
parser_name = "minimax_m2_append_think"
|
||||||
|
end_token = "</think>"
|
||||||
|
|
||||||
|
# MiniMax M2 model path
|
||||||
|
REASONING_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def minimax_m2_tokenizer():
|
||||||
|
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# MiniMaxM2AppendThinkReasoningParser behavior:
|
||||||
|
# - Prepends <think> to the beginning of the output
|
||||||
|
# - Does NOT separate reasoning and content
|
||||||
|
# - Returns everything as content (with <think> prepended)
|
||||||
|
# - reasoning is always None
|
||||||
|
#
|
||||||
|
# This parser is used when you want to keep the raw output with <think> added
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Case: simple output with end token
|
||||||
|
SIMPLE_OUTPUT = {
|
||||||
|
"output": "This is reasoning</think>This is response",
|
||||||
|
"reasoning": None,
|
||||||
|
"content": "<think>This is reasoning</think>This is response",
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: output without end token (reasoning in progress)
|
||||||
|
NO_END_TOKEN = {
|
||||||
|
"output": "This is reasoning in progress",
|
||||||
|
"reasoning": None,
|
||||||
|
"content": "<think>This is reasoning in progress",
|
||||||
|
"is_reasoning_end": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: only end token
|
||||||
|
ONLY_END_TOKEN = {
|
||||||
|
"output": "</think>This is response",
|
||||||
|
"reasoning": None,
|
||||||
|
"content": "<think></think>This is response",
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: multiple lines
|
||||||
|
MULTIPLE_LINES = {
|
||||||
|
"output": "Line 1\nLine 2</think>Response 1\nResponse 2",
|
||||||
|
"reasoning": None,
|
||||||
|
"content": "<think>Line 1\nLine 2</think>Response 1\nResponse 2",
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: empty output (non-streaming prepends <think>)
|
||||||
|
EMPTY = {
|
||||||
|
"output": "",
|
||||||
|
"reasoning": None,
|
||||||
|
"content": "<think>",
|
||||||
|
"is_reasoning_end": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: empty output streaming (no tokens = no output)
|
||||||
|
EMPTY_STREAMING = {
|
||||||
|
"output": "",
|
||||||
|
"reasoning": None,
|
||||||
|
"content": None,
|
||||||
|
"is_reasoning_end": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: special characters
|
||||||
|
SPECIAL_CHARS = {
|
||||||
|
"output": "Let me think... 1+1=2</think>Yes!",
|
||||||
|
"reasoning": None,
|
||||||
|
"content": "<think>Let me think... 1+1=2</think>Yes!",
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: code in output
|
||||||
|
CODE_OUTPUT = {
|
||||||
|
"output": "```python\nprint('hi')\n```</think>Here's the code.",
|
||||||
|
"reasoning": None,
|
||||||
|
"content": "<think>```python\nprint('hi')\n```</think>Here's the code.",
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASES = [
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
SIMPLE_OUTPUT,
|
||||||
|
id="simple_output",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
SIMPLE_OUTPUT,
|
||||||
|
id="simple_output_streaming",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
NO_END_TOKEN,
|
||||||
|
id="no_end_token",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
NO_END_TOKEN,
|
||||||
|
id="no_end_token_streaming",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
ONLY_END_TOKEN,
|
||||||
|
id="only_end_token",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
ONLY_END_TOKEN,
|
||||||
|
id="only_end_token_streaming",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
MULTIPLE_LINES,
|
||||||
|
id="multiple_lines",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
MULTIPLE_LINES,
|
||||||
|
id="multiple_lines_streaming",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
EMPTY,
|
||||||
|
id="empty",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
EMPTY_STREAMING,
|
||||||
|
id="empty_streaming",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
SPECIAL_CHARS,
|
||||||
|
id="special_chars",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
SPECIAL_CHARS,
|
||||||
|
id="special_chars_streaming",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
CODE_OUTPUT,
|
||||||
|
id="code_output",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
CODE_OUTPUT,
|
||||||
|
id="code_output_streaming",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
|
||||||
|
def test_reasoning(
|
||||||
|
streaming: bool,
|
||||||
|
param_dict: dict,
|
||||||
|
minimax_m2_tokenizer,
|
||||||
|
):
|
||||||
|
output = minimax_m2_tokenizer.tokenize(param_dict["output"])
|
||||||
|
# decode everything to tokens
|
||||||
|
output_tokens: list[str] = [
|
||||||
|
minimax_m2_tokenizer.convert_tokens_to_string([token]) for token in output
|
||||||
|
]
|
||||||
|
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
|
||||||
|
minimax_m2_tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(
|
||||||
|
parser, output_tokens, streaming=streaming
|
||||||
|
)
|
||||||
|
|
||||||
|
assert reasoning == param_dict["reasoning"]
|
||||||
|
assert content == param_dict["content"]
|
||||||
|
|
||||||
|
# Test is_reasoning_end
|
||||||
|
output_ids = minimax_m2_tokenizer.convert_tokens_to_ids(output)
|
||||||
|
is_reasoning_end = parser.is_reasoning_end(output_ids)
|
||||||
|
assert is_reasoning_end == param_dict["is_reasoning_end"]
|
||||||
230
tests/reasoning/test_minimax_m2_reasoning_parser.py
Normal file
230
tests/reasoning/test_minimax_m2_reasoning_parser.py
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from tests.reasoning.utils import run_reasoning_extraction
|
||||||
|
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||||
|
|
||||||
|
parser_name = "minimax_m2"
|
||||||
|
end_token = "</think>"
|
||||||
|
|
||||||
|
# MiniMax M2 model path
|
||||||
|
REASONING_MODEL_NAME = "MiniMaxAI/MiniMax-M2"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def minimax_m2_tokenizer():
|
||||||
|
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# MiniMax M2 specific behavior:
|
||||||
|
# - Model does NOT generate <think> start token
|
||||||
|
# - Model only generates </think> end token
|
||||||
|
# - All content before </think> is reasoning
|
||||||
|
# - All content after </think> is the actual response (content)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Case: reasoning + end token + content (typical case)
|
||||||
|
SIMPLE_REASONING = {
|
||||||
|
"output": "This is a reasoning section</think>This is the rest",
|
||||||
|
"reasoning": "This is a reasoning section",
|
||||||
|
"content": "This is the rest",
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: reasoning + end token only (no content after)
|
||||||
|
COMPLETE_REASONING = {
|
||||||
|
"output": "This is a reasoning section</think>",
|
||||||
|
"reasoning": "This is a reasoning section",
|
||||||
|
"content": None,
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: no end token yet (streaming in progress, all is reasoning)
|
||||||
|
NO_END_TOKEN = {
|
||||||
|
"output": "This is reasoning in progress",
|
||||||
|
"reasoning": "This is reasoning in progress",
|
||||||
|
"content": None,
|
||||||
|
"is_reasoning_end": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: multiple lines of reasoning
|
||||||
|
MULTIPLE_LINES = {
|
||||||
|
"output": "First line\nSecond line</think>Response first line\nResponse second",
|
||||||
|
"reasoning": "First line\nSecond line",
|
||||||
|
"content": "Response first line\nResponse second",
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: only end token (empty reasoning, immediate response)
|
||||||
|
SHORTEST_REASONING_NO_STREAMING = {
|
||||||
|
"output": "</think>This is the response",
|
||||||
|
"reasoning": "",
|
||||||
|
"content": "This is the response",
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: only end token streaming (reasoning is None because it's just the token)
|
||||||
|
SHORTEST_REASONING_STREAMING = {
|
||||||
|
"output": "</think>This is the response",
|
||||||
|
"reasoning": None,
|
||||||
|
"content": "This is the response",
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: empty output
|
||||||
|
EMPTY = {
|
||||||
|
"output": "",
|
||||||
|
"reasoning": "",
|
||||||
|
"content": None,
|
||||||
|
"is_reasoning_end": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: empty streaming
|
||||||
|
EMPTY_STREAMING = {
|
||||||
|
"output": "",
|
||||||
|
"reasoning": None,
|
||||||
|
"content": None,
|
||||||
|
"is_reasoning_end": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: long reasoning with special characters
|
||||||
|
SPECIAL_CHARS = {
|
||||||
|
"output": "Let me think... 1+1=2, right?</think>Yes, 1+1=2.",
|
||||||
|
"reasoning": "Let me think... 1+1=2, right?",
|
||||||
|
"content": "Yes, 1+1=2.",
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Case: reasoning with code blocks
|
||||||
|
CODE_IN_REASONING = {
|
||||||
|
"output": "```python\nprint('hello')\n```</think>Here is the code.",
|
||||||
|
"reasoning": "```python\nprint('hello')\n```",
|
||||||
|
"content": "Here is the code.",
|
||||||
|
"is_reasoning_end": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_CASES = [
|
||||||
|
# Core cases: no start token (MiniMax M2 actual behavior)
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
SIMPLE_REASONING,
|
||||||
|
id="simple_reasoning",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
SIMPLE_REASONING,
|
||||||
|
id="simple_reasoning_streaming",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
COMPLETE_REASONING,
|
||||||
|
id="complete_reasoning",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
COMPLETE_REASONING,
|
||||||
|
id="complete_reasoning_streaming",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
NO_END_TOKEN,
|
||||||
|
id="no_end_token",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
NO_END_TOKEN,
|
||||||
|
id="no_end_token_streaming",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
MULTIPLE_LINES,
|
||||||
|
id="multiple_lines",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
MULTIPLE_LINES,
|
||||||
|
id="multiple_lines_streaming",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
SHORTEST_REASONING_NO_STREAMING,
|
||||||
|
id="shortest_reasoning",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
SHORTEST_REASONING_STREAMING,
|
||||||
|
id="shortest_reasoning_streaming",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
EMPTY,
|
||||||
|
id="empty",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
EMPTY_STREAMING,
|
||||||
|
id="empty_streaming",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
SPECIAL_CHARS,
|
||||||
|
id="special_chars",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
SPECIAL_CHARS,
|
||||||
|
id="special_chars_streaming",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
False,
|
||||||
|
CODE_IN_REASONING,
|
||||||
|
id="code_in_reasoning",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
True,
|
||||||
|
CODE_IN_REASONING,
|
||||||
|
id="code_in_reasoning_streaming",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
|
||||||
|
def test_reasoning(
|
||||||
|
streaming: bool,
|
||||||
|
param_dict: dict,
|
||||||
|
minimax_m2_tokenizer,
|
||||||
|
):
|
||||||
|
output = minimax_m2_tokenizer.tokenize(param_dict["output"])
|
||||||
|
# decode everything to tokens
|
||||||
|
output_tokens: list[str] = [
|
||||||
|
minimax_m2_tokenizer.convert_tokens_to_string([token]) for token in output
|
||||||
|
]
|
||||||
|
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
|
||||||
|
minimax_m2_tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
|
reasoning, content = run_reasoning_extraction(
|
||||||
|
parser, output_tokens, streaming=streaming
|
||||||
|
)
|
||||||
|
|
||||||
|
assert reasoning == param_dict["reasoning"]
|
||||||
|
assert content == param_dict["content"]
|
||||||
|
|
||||||
|
# Test is_reasoning_end
|
||||||
|
output_ids = minimax_m2_tokenizer.convert_tokens_to_ids(output)
|
||||||
|
is_reasoning_end = parser.is_reasoning_end(output_ids)
|
||||||
|
assert is_reasoning_end == param_dict["is_reasoning_end"]
|
||||||
|
|
||||||
|
# Test extract_content
|
||||||
|
if param_dict["content"] is not None:
|
||||||
|
content = parser.extract_content_ids(output_ids)
|
||||||
|
assert content == minimax_m2_tokenizer.convert_tokens_to_ids(
|
||||||
|
minimax_m2_tokenizer.tokenize(param_dict["content"])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
content = parser.extract_content_ids(output)
|
||||||
|
assert content == []
|
||||||
@ -89,64 +89,6 @@ def test_update_config():
|
|||||||
new_config3 = update_config(config3, {"a": "new_value"})
|
new_config3 = update_config(config3, {"a": "new_value"})
|
||||||
|
|
||||||
|
|
||||||
# Can remove once --task option is fully deprecated
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("model_id", "expected_runner_type", "expected_convert_type", "expected_task"),
|
|
||||||
[
|
|
||||||
("distilbert/distilgpt2", "generate", "none", "generate"),
|
|
||||||
("intfloat/multilingual-e5-small", "pooling", "none", "embed"),
|
|
||||||
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"),
|
|
||||||
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none", "classify"),
|
|
||||||
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "embed"),
|
|
||||||
("openai/whisper-small", "generate", "none", "transcription"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_auto_task(
|
|
||||||
model_id, expected_runner_type, expected_convert_type, expected_task
|
|
||||||
):
|
|
||||||
config = ModelConfig(model_id, task="auto")
|
|
||||||
|
|
||||||
assert config.runner_type == expected_runner_type
|
|
||||||
assert config.convert_type == expected_convert_type
|
|
||||||
|
|
||||||
|
|
||||||
# Can remove once --task option is fully deprecated
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("model_id", "expected_runner_type", "expected_convert_type", "expected_task"),
|
|
||||||
[
|
|
||||||
("distilbert/distilgpt2", "pooling", "embed", "embed"),
|
|
||||||
("intfloat/multilingual-e5-small", "pooling", "embed", "embed"),
|
|
||||||
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"),
|
|
||||||
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify", "classify"),
|
|
||||||
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed", "embed"),
|
|
||||||
("openai/whisper-small", "pooling", "embed", "embed"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_score_task(
|
|
||||||
model_id, expected_runner_type, expected_convert_type, expected_task
|
|
||||||
):
|
|
||||||
config = ModelConfig(model_id, task="score")
|
|
||||||
|
|
||||||
assert config.runner_type == expected_runner_type
|
|
||||||
assert config.convert_type == expected_convert_type
|
|
||||||
|
|
||||||
|
|
||||||
# Can remove once --task option is fully deprecated
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
("model_id", "expected_runner_type", "expected_convert_type", "expected_task"),
|
|
||||||
[
|
|
||||||
("openai/whisper-small", "generate", "none", "transcription"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_transcription_task(
|
|
||||||
model_id, expected_runner_type, expected_convert_type, expected_task
|
|
||||||
):
|
|
||||||
config = ModelConfig(model_id, task="transcription")
|
|
||||||
|
|
||||||
assert config.runner_type == expected_runner_type
|
|
||||||
assert config.convert_type == expected_convert_type
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
("model_id", "expected_runner_type", "expected_convert_type"),
|
("model_id", "expected_runner_type", "expected_convert_type"),
|
||||||
[
|
[
|
||||||
@ -1085,7 +1027,7 @@ def test_vllm_config_explicit_overrides():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Override one field but not others
|
# Override one field but not others
|
||||||
pass_config = PassConfig(enable_noop=False)
|
pass_config = PassConfig(eliminate_noops=False)
|
||||||
compilation_config = CompilationConfig(pass_config=pass_config)
|
compilation_config = CompilationConfig(pass_config=pass_config)
|
||||||
config = VllmConfig(
|
config = VllmConfig(
|
||||||
model_config=regular_model,
|
model_config=regular_model,
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import pytest
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.envs import (
|
from vllm.envs import (
|
||||||
|
disable_envs_cache,
|
||||||
enable_envs_cache,
|
enable_envs_cache,
|
||||||
env_list_with_choices,
|
env_list_with_choices,
|
||||||
env_set_with_choices,
|
env_set_with_choices,
|
||||||
@ -57,6 +58,43 @@ def test_getattr_with_cache(monkeypatch: pytest.MonkeyPatch):
|
|||||||
envs.__getattr__ = envs.__getattr__.__wrapped__
|
envs.__getattr__ = envs.__getattr__.__wrapped__
|
||||||
|
|
||||||
|
|
||||||
|
def test_getattr_with_reset(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
monkeypatch.setenv("VLLM_HOST_IP", "1.1.1.1")
|
||||||
|
# __getattr__ is not decorated with functools.cache
|
||||||
|
assert not hasattr(envs.__getattr__, "cache_info")
|
||||||
|
|
||||||
|
# Enable envs cache and ignore ongoing environment changes
|
||||||
|
enable_envs_cache()
|
||||||
|
assert envs.VLLM_HOST_IP == "1.1.1.1"
|
||||||
|
# With cache enabled, the environment variable value is cached and unchanged
|
||||||
|
monkeypatch.setenv("VLLM_HOST_IP", "2.2.2.2")
|
||||||
|
assert envs.VLLM_HOST_IP == "1.1.1.1"
|
||||||
|
|
||||||
|
disable_envs_cache()
|
||||||
|
assert envs.VLLM_HOST_IP == "2.2.2.2"
|
||||||
|
# After cache disabled, the environment variable value would be synced
|
||||||
|
# with os.environ
|
||||||
|
monkeypatch.setenv("VLLM_HOST_IP", "3.3.3.3")
|
||||||
|
assert envs.VLLM_HOST_IP == "3.3.3.3"
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_envs_cache_enabled() -> None:
|
||||||
|
assert not envs._is_envs_cache_enabled()
|
||||||
|
enable_envs_cache()
|
||||||
|
assert envs._is_envs_cache_enabled()
|
||||||
|
|
||||||
|
# Only wrap one-layer of cache, so we only need to
|
||||||
|
# call disable once to reset.
|
||||||
|
enable_envs_cache()
|
||||||
|
enable_envs_cache()
|
||||||
|
enable_envs_cache()
|
||||||
|
disable_envs_cache()
|
||||||
|
assert not envs._is_envs_cache_enabled()
|
||||||
|
|
||||||
|
disable_envs_cache()
|
||||||
|
assert not envs._is_envs_cache_enabled()
|
||||||
|
|
||||||
|
|
||||||
class TestEnvWithChoices:
|
class TestEnvWithChoices:
|
||||||
"""Test cases for env_with_choices function."""
|
"""Test cases for env_with_choices function."""
|
||||||
|
|
||||||
|
|||||||
@ -119,7 +119,7 @@ class RemoteOpenAIServer:
|
|||||||
vllm_serve_args: list[str],
|
vllm_serve_args: list[str],
|
||||||
*,
|
*,
|
||||||
env_dict: dict[str, str] | None = None,
|
env_dict: dict[str, str] | None = None,
|
||||||
seed: int | None = 0,
|
seed: int = 0,
|
||||||
auto_port: bool = True,
|
auto_port: bool = True,
|
||||||
max_wait_seconds: float | None = None,
|
max_wait_seconds: float | None = None,
|
||||||
override_hf_configs: dict[str, Any] | None = None,
|
override_hf_configs: dict[str, Any] | None = None,
|
||||||
@ -283,7 +283,7 @@ class RemoteOpenAIServerCustom(RemoteOpenAIServer):
|
|||||||
child_process_fxn: Callable[[dict[str, str] | None, str, list[str]], None],
|
child_process_fxn: Callable[[dict[str, str] | None, str, list[str]], None],
|
||||||
*,
|
*,
|
||||||
env_dict: dict[str, str] | None = None,
|
env_dict: dict[str, str] | None = None,
|
||||||
seed: int | None = 0,
|
seed: int = 0,
|
||||||
auto_port: bool = True,
|
auto_port: bool = True,
|
||||||
max_wait_seconds: float | None = None,
|
max_wait_seconds: float | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@ -106,8 +106,8 @@ def create_common_attn_metadata(
|
|||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
query_start_loc_cpu=query_start_loc_cpu,
|
query_start_loc_cpu=query_start_loc_cpu,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
seq_lens_cpu=seq_lens_cpu,
|
_seq_lens_cpu=seq_lens_cpu,
|
||||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
_num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||||
num_reqs=batch_spec.batch_size,
|
num_reqs=batch_spec.batch_size,
|
||||||
num_actual_tokens=num_tokens,
|
num_actual_tokens=num_tokens,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import torch._dynamo.config as dynamo_config
|
|||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.logprobs import Logprob
|
from vllm.logprobs import Logprob
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.sampling_params import StructuredOutputsParams
|
from vllm.sampling_params import StructuredOutputsParams
|
||||||
from vllm.v1.metrics.reader import Metric
|
from vllm.v1.metrics.reader import Metric
|
||||||
|
|
||||||
@ -70,6 +71,18 @@ def test_without_spec_decoding(
|
|||||||
(True, "uni", True, None, True),
|
(True, "uni", True, None, True),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
# On ROCm, Only test with structured_outputs (deterministic)
|
||||||
|
# and skip chunk_prefill (more variable).
|
||||||
|
test_configs = [
|
||||||
|
cfg
|
||||||
|
for cfg in test_configs
|
||||||
|
if not cfg[4] # skip chunk_prefill=True
|
||||||
|
]
|
||||||
|
test_sampling_params = [
|
||||||
|
p for p in test_sampling_params if p.get("structured_outputs") is not None
|
||||||
|
]
|
||||||
|
|
||||||
run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
|
run_tests(monkeypatch, MODEL, test_configs, test_sampling_params)
|
||||||
|
|
||||||
|
|
||||||
@ -108,7 +121,14 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
|
|||||||
(True, "uni", True, spec_config_short, True),
|
(True, "uni", True, spec_config_short, True),
|
||||||
]
|
]
|
||||||
|
|
||||||
run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)
|
# On ROCm, use TRITON_ATTN + float32 for better numerical consistency
|
||||||
|
run_tests(
|
||||||
|
monkeypatch,
|
||||||
|
MTP_MODEL,
|
||||||
|
test_configs,
|
||||||
|
test_sampling_params,
|
||||||
|
is_testing_with_spec_decoding=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dynamo_config.patch(cache_size_limit=16)
|
@dynamo_config.patch(cache_size_limit=16)
|
||||||
@ -117,15 +137,23 @@ def run_tests(
|
|||||||
model: str,
|
model: str,
|
||||||
test_configs: list[tuple],
|
test_configs: list[tuple],
|
||||||
test_sampling_params: list[dict[str, Any]],
|
test_sampling_params: list[dict[str, Any]],
|
||||||
|
is_testing_with_spec_decoding: bool = False,
|
||||||
):
|
):
|
||||||
"""Test consistency of combos of async scheduling, preemption,
|
"""Test consistency of combos of async scheduling, preemption,
|
||||||
uni/multiproc executor with spec decoding."""
|
uni/multiproc executor with spec decoding."""
|
||||||
|
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
# avoid precision errors
|
# avoid precision errors
|
||||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
if current_platform.is_rocm():
|
||||||
# lock matmul precision to full FP32
|
if is_testing_with_spec_decoding:
|
||||||
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "highest")
|
# Use TRITON_ATTN for spec decoding test for consistency
|
||||||
|
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
|
||||||
|
else:
|
||||||
|
m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
|
||||||
|
else:
|
||||||
|
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
||||||
|
# lock matmul precision to full FP32 (IEEE)
|
||||||
|
m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee")
|
||||||
# m.setenv("VLLM_BATCH_INVARIANT", "1")
|
# m.setenv("VLLM_BATCH_INVARIANT", "1")
|
||||||
outputs: list[tuple[str, list, list]] = []
|
outputs: list[tuple[str, list, list]] = []
|
||||||
for n, (
|
for n, (
|
||||||
@ -145,6 +173,7 @@ def run_tests(
|
|||||||
async_scheduling,
|
async_scheduling,
|
||||||
spec_config,
|
spec_config,
|
||||||
test_prefill_chunking=test_prefill_chunking,
|
test_prefill_chunking=test_prefill_chunking,
|
||||||
|
is_testing_with_spec_decoding=is_testing_with_spec_decoding,
|
||||||
)
|
)
|
||||||
outputs.append(test_results)
|
outputs.append(test_results)
|
||||||
|
|
||||||
@ -174,17 +203,34 @@ def run_tests(
|
|||||||
name_0=f"baseline=[{baseline_config}], params={params}",
|
name_0=f"baseline=[{baseline_config}], params={params}",
|
||||||
name_1=f"config=[{test_config}], params={params}",
|
name_1=f"config=[{test_config}], params={params}",
|
||||||
)
|
)
|
||||||
assert _all_logprobs_match(base_logprobs, test_logprobs)
|
|
||||||
|
# On ROCm with TRITON_ATTN (spec decoding test), skip strict
|
||||||
|
# logprobs comparison when logprobs are requested
|
||||||
|
skip_logprobs_check = (
|
||||||
|
current_platform.is_rocm()
|
||||||
|
and params.get("logprobs")
|
||||||
|
and is_testing_with_spec_decoding
|
||||||
|
)
|
||||||
|
if not skip_logprobs_check:
|
||||||
|
assert _all_logprobs_match(base_logprobs, test_logprobs)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
base_acceptance_rate is not None
|
base_acceptance_rate is not None
|
||||||
and test_acceptance_rate is not None
|
and test_acceptance_rate is not None
|
||||||
):
|
):
|
||||||
if "spec_mml=None" in test_config:
|
if "spec_mml=None" in test_config:
|
||||||
|
# Preemption causes more variance in acceptance rates
|
||||||
|
if (
|
||||||
|
current_platform.is_rocm()
|
||||||
|
and "preemption=True" in test_config
|
||||||
|
):
|
||||||
|
tolerance = 0.10
|
||||||
|
else:
|
||||||
|
tolerance = 0.05
|
||||||
assert (
|
assert (
|
||||||
test_acceptance_rate > base_acceptance_rate
|
test_acceptance_rate > base_acceptance_rate
|
||||||
or test_acceptance_rate
|
or test_acceptance_rate
|
||||||
== pytest.approx(base_acceptance_rate, rel=5e-2)
|
== pytest.approx(base_acceptance_rate, rel=tolerance)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Currently the reported acceptance rate is expected to be
|
# Currently the reported acceptance rate is expected to be
|
||||||
@ -215,6 +261,7 @@ def run_test(
|
|||||||
async_scheduling: bool,
|
async_scheduling: bool,
|
||||||
spec_config: dict[str, Any] | None,
|
spec_config: dict[str, Any] | None,
|
||||||
test_prefill_chunking: bool,
|
test_prefill_chunking: bool,
|
||||||
|
is_testing_with_spec_decoding: bool = False,
|
||||||
):
|
):
|
||||||
spec_decoding = spec_config is not None
|
spec_decoding = spec_config is not None
|
||||||
cache_arg: dict[str, Any] = (
|
cache_arg: dict[str, Any] = (
|
||||||
@ -233,6 +280,15 @@ def run_test(
|
|||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
print(f"---- TESTING {test_str}: {test_config}")
|
print(f"---- TESTING {test_str}: {test_config}")
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
|
|
||||||
|
# On ROCm: use float16 for first test (ROCM_AITER_FA), but float32 for
|
||||||
|
# spec decoding test (TRITON_ATTN) for better precision.
|
||||||
|
# On others: always use float32.
|
||||||
|
if current_platform.is_rocm() and not is_testing_with_spec_decoding:
|
||||||
|
dtype = "float16"
|
||||||
|
else:
|
||||||
|
dtype = "float32"
|
||||||
|
|
||||||
with VllmRunner(
|
with VllmRunner(
|
||||||
model,
|
model,
|
||||||
max_model_len=512,
|
max_model_len=512,
|
||||||
@ -242,7 +298,7 @@ def run_test(
|
|||||||
# enforce_eager=True,
|
# enforce_eager=True,
|
||||||
async_scheduling=async_scheduling,
|
async_scheduling=async_scheduling,
|
||||||
distributed_executor_backend=executor,
|
distributed_executor_backend=executor,
|
||||||
dtype="float32", # avoid precision errors
|
dtype=dtype,
|
||||||
speculative_config=spec_config,
|
speculative_config=spec_config,
|
||||||
disable_log_stats=False,
|
disable_log_stats=False,
|
||||||
**cache_arg,
|
**cache_arg,
|
||||||
@ -302,11 +358,21 @@ def _all_logprobs_match(req_a, req_b) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
|
def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
|
||||||
return len(lps_a) == len(lps_b) and all(
|
if current_platform.is_rocm():
|
||||||
a.decoded_token == b.decoded_token
|
# ROCm has higher numerical variance
|
||||||
and a.rank == b.rank
|
# due to use of float16.
|
||||||
and a.logprob == pytest.approx(b.logprob, rel=1e-3, abs=1e-6)
|
rel_tol, abs_tol = 5e-2, 1e-5
|
||||||
for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
|
else:
|
||||||
|
rel_tol, abs_tol = 1e-3, 1e-6
|
||||||
|
return (
|
||||||
|
len(lps_a) == len(lps_b)
|
||||||
|
and lps_a.keys() == lps_b.keys()
|
||||||
|
and all(
|
||||||
|
a.decoded_token == b.decoded_token
|
||||||
|
and a.rank == b.rank
|
||||||
|
and a.logprob == pytest.approx(b.logprob, rel=rel_tol, abs=abs_tol)
|
||||||
|
for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
131
tests/v1/e2e/test_async_spec_decode.py
Normal file
131
tests/v1/e2e/test_async_spec_decode.py
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Test that verifies no implicit GPU-CPU synchronization occurs during
|
||||||
|
speculative decoding generation under expected conditions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import multiprocessing
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sync_tracker():
|
||||||
|
"""
|
||||||
|
Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect
|
||||||
|
lazy init syncs. Prints stack traces immediately when syncs occur.
|
||||||
|
"""
|
||||||
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
|
|
||||||
|
# Shared counter for cross-process communication (inherited by fork)
|
||||||
|
sync_count = multiprocessing.Value("i", 0)
|
||||||
|
|
||||||
|
# Save original property
|
||||||
|
original_prop = CommonAttentionMetadata.seq_lens_cpu
|
||||||
|
original_fget = original_prop.fget
|
||||||
|
|
||||||
|
# Create tracking wrapper
|
||||||
|
def tracking_seq_lens_cpu(self):
|
||||||
|
if self._seq_lens_cpu is None:
|
||||||
|
# Increment counter
|
||||||
|
with sync_count.get_lock():
|
||||||
|
sync_count.value += 1
|
||||||
|
count = sync_count.value
|
||||||
|
# Print stack trace immediately (shows in subprocess output)
|
||||||
|
print(f"\n{'=' * 60}", file=sys.stderr)
|
||||||
|
print(f"SYNC #{count}: seq_lens_cpu lazy init triggered!", file=sys.stderr)
|
||||||
|
print(f"{'=' * 60}", file=sys.stderr)
|
||||||
|
traceback.print_stack(file=sys.stderr)
|
||||||
|
print(f"{'=' * 60}\n", file=sys.stderr)
|
||||||
|
sys.stderr.flush()
|
||||||
|
return original_fget(self)
|
||||||
|
|
||||||
|
# Apply patch
|
||||||
|
CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu)
|
||||||
|
|
||||||
|
class SyncTracker:
|
||||||
|
@property
|
||||||
|
def count(self) -> int:
|
||||||
|
return sync_count.value
|
||||||
|
|
||||||
|
def assert_no_sync(self, msg: str = ""):
|
||||||
|
count = sync_count.value
|
||||||
|
assert count == 0, (
|
||||||
|
f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered "
|
||||||
|
f"{count} times. See stack traces above. {msg}"
|
||||||
|
)
|
||||||
|
|
||||||
|
yield SyncTracker()
|
||||||
|
|
||||||
|
# Restore original property
|
||||||
|
CommonAttentionMetadata.seq_lens_cpu = original_prop
|
||||||
|
torch._dynamo.reset()
|
||||||
|
|
||||||
|
|
||||||
|
# Test configurations: (model, spec_model, method, num_spec_tokens, backend_env)
|
||||||
|
SPEC_DECODE_CONFIGS = [
|
||||||
|
pytest.param(
|
||||||
|
"meta-llama/Llama-3.2-1B-Instruct",
|
||||||
|
"nm-testing/Llama3_2_1B_speculator.eagle3",
|
||||||
|
"eagle3",
|
||||||
|
2,
|
||||||
|
id="eagle3-llama",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
"eagle618/deepseek-v3-random",
|
||||||
|
"eagle618/eagle-deepseek-v3-random",
|
||||||
|
"eagle",
|
||||||
|
2,
|
||||||
|
id="eagle-mla-deepseek",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"model,spec_model,method,num_spec_tokens",
|
||||||
|
SPEC_DECODE_CONFIGS,
|
||||||
|
)
|
||||||
|
def test_no_sync_with_spec_decode(
|
||||||
|
sync_tracker,
|
||||||
|
model: str,
|
||||||
|
spec_model: str,
|
||||||
|
method: str,
|
||||||
|
num_spec_tokens: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test that no implicit GPU-CPU sync occurs during speculative decoding
|
||||||
|
generation.
|
||||||
|
"""
|
||||||
|
# Import vLLM AFTER sync_tracker fixture has applied the patch
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.distributed import cleanup_dist_env_and_memory
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
model=model,
|
||||||
|
max_model_len=256,
|
||||||
|
speculative_config={
|
||||||
|
"method": method,
|
||||||
|
"num_speculative_tokens": num_spec_tokens,
|
||||||
|
"model": spec_model,
|
||||||
|
},
|
||||||
|
enforce_eager=True,
|
||||||
|
async_scheduling=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = llm.generate(
|
||||||
|
["Hello, my name is"],
|
||||||
|
SamplingParams(temperature=0, max_tokens=10),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(outputs) == 1
|
||||||
|
assert len(outputs[0].outputs[0].text) > 0
|
||||||
|
|
||||||
|
del llm
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
cleanup_dist_env_and_memory()
|
||||||
|
|
||||||
|
sync_tracker.assert_no_sync()
|
||||||
@ -191,8 +191,8 @@ def test_suffix_decoding_acceptance(
|
|||||||
# Expect the acceptance rate to improve.
|
# Expect the acceptance rate to improve.
|
||||||
assert first_accept_rate < last_accept_rate
|
assert first_accept_rate < last_accept_rate
|
||||||
|
|
||||||
# Heuristic: expect at least 82.5% acceptance rate at the end.
|
# Heuristic: expect at least 80.0% acceptance rate at the end.
|
||||||
assert last_accept_rate > 0.825
|
assert last_accept_rate > 0.80
|
||||||
|
|
||||||
del spec_llm
|
del spec_llm
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|||||||
54
tests/v1/engine/test_init_error_messaging.py
Normal file
54
tests/v1/engine/test_init_error_messaging.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.v1.core.kv_cache_utils import check_enough_kv_cache_memory
|
||||||
|
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||||
|
|
||||||
|
|
||||||
|
def test_kv_cache_oom_no_memory():
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
config = MagicMock()
|
||||||
|
config.model_config.max_model_len = 2048
|
||||||
|
|
||||||
|
spec = {
|
||||||
|
"layer_0": FullAttentionSpec(
|
||||||
|
block_size=16,
|
||||||
|
num_kv_heads=8,
|
||||||
|
head_size=128,
|
||||||
|
dtype="float16",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
check_enough_kv_cache_memory(config, spec, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_kv_cache_oom_insufficient_memory(monkeypatch):
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
config = MagicMock()
|
||||||
|
config.model_config.max_model_len = 2048
|
||||||
|
config.cache_config.block_size = 16
|
||||||
|
config.parallel_config.tensor_parallel_size = 1
|
||||||
|
config.parallel_config.pipeline_parallel_size = 1
|
||||||
|
config.parallel_config.decode_context_parallel_size = 1
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"vllm.v1.core.kv_cache_utils.max_memory_usage_bytes",
|
||||||
|
lambda c, s: 100 * 1024**3, # 100 GiB
|
||||||
|
)
|
||||||
|
|
||||||
|
spec = {
|
||||||
|
"layer_0": FullAttentionSpec(
|
||||||
|
block_size=16,
|
||||||
|
num_kv_heads=8,
|
||||||
|
head_size=128,
|
||||||
|
dtype="float16",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
check_enough_kv_cache_memory(config, spec, 1024**3) # 1 GiB
|
||||||
163
tests/v1/kv_connector/unit/test_cache_pollution_prevention.py
Normal file
163
tests/v1/kv_connector/unit/test_cache_pollution_prevention.py
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""
|
||||||
|
test that invalid blocks are evicted from prefix cache to prevent pollution.
|
||||||
|
|
||||||
|
verifies that when sync-loading fails, invalid blocks are removed from the
|
||||||
|
prefix cache hash table so future requests cannot match and reuse corrupted data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.v1.core.sched.scheduler import Scheduler
|
||||||
|
from vllm.v1.request import Request, RequestStatus
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
create_model_runner_output,
|
||||||
|
create_request,
|
||||||
|
create_scheduler,
|
||||||
|
create_vllm_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.cpu_test
|
||||||
|
|
||||||
|
|
||||||
|
def _make_get_num_new_matched_tokens(
|
||||||
|
req_num_new_matched_tokens: dict[str, int],
|
||||||
|
async_load: bool,
|
||||||
|
) -> Callable[[Request, int], tuple[int, bool]]:
|
||||||
|
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
|
||||||
|
value = req_num_new_matched_tokens.get(request.request_id, 0)
|
||||||
|
return value, async_load
|
||||||
|
|
||||||
|
return get_num_new_matched_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fail_scheduler():
|
||||||
|
"""scheduler with kv_load_failure_policy='fail'"""
|
||||||
|
vllm_config = create_vllm_config()
|
||||||
|
vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
|
||||||
|
return create_scheduler(vllm_config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_blocks_evicted_prevents_cache_pollution(
|
||||||
|
fail_scheduler: Scheduler,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
verify invalid blocks are evicted to prevent future cache hits.
|
||||||
|
|
||||||
|
scenario:
|
||||||
|
1. request 1 loads externally-computed blocks (sync mode)
|
||||||
|
2. some blocks fail to load and are marked invalid
|
||||||
|
3. with fail policy, invalid blocks should be evicted from prefix cache
|
||||||
|
4. request is marked as FINISHED_ERROR
|
||||||
|
"""
|
||||||
|
num_prompt_blocks = 100
|
||||||
|
num_external_computed_blocks = 99
|
||||||
|
invalid_block_idx = 50
|
||||||
|
|
||||||
|
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
|
||||||
|
num_external_computed_tokens = (
|
||||||
|
num_external_computed_blocks * fail_scheduler.block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# request 1: will have invalid blocks
|
||||||
|
request1 = create_request(num_tokens=num_prompt_tokens, request_id=1)
|
||||||
|
fail_scheduler.add_request(request=request1)
|
||||||
|
|
||||||
|
req_num_new_matched_tokens = {
|
||||||
|
request1.request_id: num_external_computed_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
# mock connector indicating sync load
|
||||||
|
fail_scheduler.connector = Mock()
|
||||||
|
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
|
||||||
|
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
|
||||||
|
)
|
||||||
|
fail_scheduler.connector.request_finished.return_value = (False, None)
|
||||||
|
fail_scheduler.connector.take_events.return_value = ()
|
||||||
|
|
||||||
|
scheduler_output = fail_scheduler.schedule()
|
||||||
|
|
||||||
|
# request should be running with sync KV load
|
||||||
|
assert len(fail_scheduler.running) == 1
|
||||||
|
assert request1.status == RequestStatus.RUNNING
|
||||||
|
|
||||||
|
# get allocated block IDs
|
||||||
|
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
|
||||||
|
invalid_block_id = req_block_ids[invalid_block_idx]
|
||||||
|
invalid_block_ids = {invalid_block_id}
|
||||||
|
|
||||||
|
# get the block object to verify eviction later
|
||||||
|
block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
|
||||||
|
|
||||||
|
# cache the blocks to simulate they've been computed and cached
|
||||||
|
# (in real scenario blocks would be cached after compute)
|
||||||
|
fail_scheduler.kv_cache_manager.cache_blocks(request1, num_external_computed_tokens)
|
||||||
|
|
||||||
|
# verify block has a hash (is cached) before reporting invalid blocks
|
||||||
|
assert block.block_hash is not None, (
|
||||||
|
f"block {invalid_block_id} should be cached (have a hash) before "
|
||||||
|
f"eviction test, but hash is None"
|
||||||
|
)
|
||||||
|
|
||||||
|
# report invalid blocks
|
||||||
|
model_runner_output = create_model_runner_output(
|
||||||
|
[request1],
|
||||||
|
invalid_block_ids=invalid_block_ids,
|
||||||
|
use_eos=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
fail_scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
|
|
||||||
|
# verify request finished with error (fail policy)
|
||||||
|
assert request1.status == RequestStatus.FINISHED_ERROR
|
||||||
|
|
||||||
|
# critical assertion: invalid block and all subsequent blocks should be evicted
|
||||||
|
# all blocks from invalid_block_idx onwards become invalid since they were
|
||||||
|
# computed based on the failed block
|
||||||
|
for idx in range(invalid_block_idx, len(req_block_ids)):
|
||||||
|
block_id = req_block_ids[idx]
|
||||||
|
block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
|
||||||
|
assert block_obj.block_hash is None, (
|
||||||
|
f"block {block_id} at index {idx} should have been evicted "
|
||||||
|
f"(hash reset to None), but hash is {block_obj.block_hash}. "
|
||||||
|
f"All blocks from index {invalid_block_idx} onwards should be evicted "
|
||||||
|
f"since they depend on the invalid block at index {invalid_block_idx}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# verify cache contains exactly the valid blocks (before first affected block)
|
||||||
|
# and none of the invalid blocks (from first affected block onwards)
|
||||||
|
|
||||||
|
# valid blocks: all blocks before invalid_block_idx should be cached
|
||||||
|
for idx in range(invalid_block_idx):
|
||||||
|
block_id = req_block_ids[idx]
|
||||||
|
block_obj = fail_scheduler.kv_cache_manager.block_pool.blocks[block_id]
|
||||||
|
assert block_obj.block_hash is not None, (
|
||||||
|
f"valid block {block_id} at index {idx} should still be cached "
|
||||||
|
f"(have a hash), but hash is None. Only blocks from index "
|
||||||
|
f"{invalid_block_idx} onwards should be evicted."
|
||||||
|
)
|
||||||
|
|
||||||
|
# invalid blocks: verify they're not in the cached_block_hash_to_block map
|
||||||
|
cached_blocks = (
|
||||||
|
fail_scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
|
||||||
|
)
|
||||||
|
cached_block_ids = {
|
||||||
|
b.block_id
|
||||||
|
for blocks_val in cached_blocks._cache.values()
|
||||||
|
for b in (
|
||||||
|
[blocks_val] if not isinstance(blocks_val, dict) else blocks_val.values()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx in range(invalid_block_idx, len(req_block_ids)):
|
||||||
|
block_id = req_block_ids[idx]
|
||||||
|
assert block_id not in cached_block_ids, (
|
||||||
|
f"invalid block {block_id} at index {idx} should not be in cache hash table"
|
||||||
|
)
|
||||||
147
tests/v1/kv_connector/unit/test_error_propagation.py
Normal file
147
tests/v1/kv_connector/unit/test_error_propagation.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.v1.core.sched.scheduler import Scheduler
|
||||||
|
from vllm.v1.request import FinishReason, Request, RequestStatus
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
create_model_runner_output,
|
||||||
|
create_request,
|
||||||
|
create_scheduler,
|
||||||
|
create_vllm_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.cpu_test
|
||||||
|
|
||||||
|
|
||||||
|
def _make_get_num_new_matched_tokens(
|
||||||
|
req_num_new_matched_tokens: dict[str, int],
|
||||||
|
async_load: bool,
|
||||||
|
) -> Callable[[Request, int], tuple[int, bool]]:
|
||||||
|
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
|
||||||
|
value = req_num_new_matched_tokens.get(request.request_id, 0)
|
||||||
|
return value, async_load
|
||||||
|
|
||||||
|
return get_num_new_matched_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fail_scheduler():
|
||||||
|
"""scheduler with kv_load_failure_policy='fail'"""
|
||||||
|
vllm_config = create_vllm_config()
|
||||||
|
vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
|
||||||
|
return create_scheduler(vllm_config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_propagation_sync_load(fail_scheduler: Scheduler):
|
||||||
|
"""test invalid_block_ids with fail policy -> FINISHED_ERROR (sync load)"""
|
||||||
|
num_prompt_blocks = 100
|
||||||
|
num_external_computed_blocks = 99
|
||||||
|
invalid_block_idx = 50
|
||||||
|
|
||||||
|
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
|
||||||
|
num_external_computed_tokens = (
|
||||||
|
num_external_computed_blocks * fail_scheduler.block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
request = create_request(num_tokens=num_prompt_tokens)
|
||||||
|
fail_scheduler.add_request(request=request)
|
||||||
|
|
||||||
|
req_num_new_matched_tokens = {
|
||||||
|
request.request_id: num_external_computed_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
fail_scheduler.connector = Mock()
|
||||||
|
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
|
||||||
|
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
|
||||||
|
)
|
||||||
|
fail_scheduler.connector.request_finished.return_value = (False, None)
|
||||||
|
fail_scheduler.connector.take_events.return_value = ()
|
||||||
|
|
||||||
|
scheduler_output = fail_scheduler.schedule()
|
||||||
|
|
||||||
|
assert len(fail_scheduler.running) == 1
|
||||||
|
assert len(scheduler_output.scheduled_new_reqs) == 1
|
||||||
|
assert fail_scheduler.connector.get_num_new_matched_tokens.call_count == 1
|
||||||
|
|
||||||
|
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
|
||||||
|
invalid_block_ids = {req_block_ids[invalid_block_idx]}
|
||||||
|
model_runner_output = create_model_runner_output(
|
||||||
|
[request],
|
||||||
|
invalid_block_ids=invalid_block_ids,
|
||||||
|
use_eos=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
|
|
||||||
|
assert request.status == RequestStatus.FINISHED_ERROR
|
||||||
|
assert request.get_finished_reason() == FinishReason.ERROR
|
||||||
|
|
||||||
|
assert len(outputs) == 1
|
||||||
|
engine_outputs = next(iter(outputs.values()))
|
||||||
|
assert len(engine_outputs.outputs) == 1
|
||||||
|
output = engine_outputs.outputs[0]
|
||||||
|
assert output.request_id == request.request_id
|
||||||
|
assert output.finish_reason == FinishReason.ERROR
|
||||||
|
|
||||||
|
assert len(fail_scheduler.running) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_propagation_async_load(fail_scheduler: Scheduler):
|
||||||
|
"""test invalid_block_ids with fail policy -> FINISHED_ERROR (async load)"""
|
||||||
|
num_prompt_blocks = 100
|
||||||
|
num_external_computed_blocks = 99
|
||||||
|
invalid_block_idx = 50
|
||||||
|
|
||||||
|
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
|
||||||
|
num_external_computed_tokens = (
|
||||||
|
num_external_computed_blocks * fail_scheduler.block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
request = create_request(num_tokens=num_prompt_tokens)
|
||||||
|
fail_scheduler.add_request(request=request)
|
||||||
|
|
||||||
|
req_num_new_matched_tokens = {
|
||||||
|
request.request_id: num_external_computed_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
fail_scheduler.connector = Mock()
|
||||||
|
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
|
||||||
|
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
|
||||||
|
)
|
||||||
|
fail_scheduler.connector.request_finished.return_value = (False, None)
|
||||||
|
fail_scheduler.connector.take_events.return_value = ()
|
||||||
|
|
||||||
|
scheduler_output = fail_scheduler.schedule()
|
||||||
|
|
||||||
|
assert len(fail_scheduler.waiting) == 1
|
||||||
|
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
|
assert request.num_computed_tokens == 0
|
||||||
|
|
||||||
|
(req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
|
||||||
|
invalid_block_ids = {req_block_ids[invalid_block_idx]}
|
||||||
|
model_runner_output = create_model_runner_output(
|
||||||
|
reqs=[],
|
||||||
|
finished_recving=set(),
|
||||||
|
invalid_block_ids=invalid_block_ids,
|
||||||
|
use_eos=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
|
|
||||||
|
assert request.status == RequestStatus.FINISHED_ERROR
|
||||||
|
assert request.get_finished_reason() == FinishReason.ERROR
|
||||||
|
|
||||||
|
assert len(outputs) == 1
|
||||||
|
engine_outputs = next(iter(outputs.values()))
|
||||||
|
assert len(engine_outputs.outputs) == 1
|
||||||
|
output = engine_outputs.outputs[0]
|
||||||
|
assert output.request_id == request.request_id
|
||||||
|
assert output.finish_reason == FinishReason.ERROR
|
||||||
|
|
||||||
|
assert len(fail_scheduler.waiting) == 0
|
||||||
454
tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py
Normal file
454
tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py
Normal file
@ -0,0 +1,454 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""
|
||||||
|
Tests for correctness in invalid block handling.
|
||||||
|
|
||||||
|
These tests verify correct behavior in three scenarios:
|
||||||
|
1. Sync recompute case: Blocks should not be freed for running requests
|
||||||
|
that need to recompute invalid blocks
|
||||||
|
2. Sync fail case: Invalid blocks must be evicted from cache when request fails
|
||||||
|
3. Async recompute case: Invalid blocks should not be cached after transfer
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.v1.core.sched.scheduler import Scheduler
|
||||||
|
from vllm.v1.request import FinishReason, Request, RequestStatus
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
create_model_runner_output,
|
||||||
|
create_request,
|
||||||
|
create_scheduler,
|
||||||
|
create_vllm_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.cpu_test
|
||||||
|
|
||||||
|
|
||||||
|
def _make_get_num_new_matched_tokens(
|
||||||
|
req_num_new_matched_tokens: dict[str, int],
|
||||||
|
async_load: bool,
|
||||||
|
) -> Callable[[Request, int], tuple[int, bool]]:
|
||||||
|
def get_num_new_matched_tokens(request: Request, _: int) -> tuple[int, bool]:
|
||||||
|
value = req_num_new_matched_tokens.get(request.request_id, 0)
|
||||||
|
return value, async_load
|
||||||
|
|
||||||
|
return get_num_new_matched_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fail_scheduler():
|
||||||
|
"""scheduler with kv_load_failure_policy='fail'"""
|
||||||
|
vllm_config = create_vllm_config()
|
||||||
|
vllm_config.kv_transfer_config.kv_load_failure_policy = "fail"
|
||||||
|
return create_scheduler(vllm_config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def recompute_scheduler():
|
||||||
|
"""scheduler with kv_load_failure_policy='recompute'"""
|
||||||
|
vllm_config = create_vllm_config()
|
||||||
|
vllm_config.kv_transfer_config.kv_load_failure_policy = "recompute"
|
||||||
|
return create_scheduler(vllm_config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_recompute_blocks_not_freed_for_running_requests(
|
||||||
|
recompute_scheduler: Scheduler,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test sync recompute case - blocks must not be freed for running requests.
|
||||||
|
|
||||||
|
When a running request has invalid blocks and retry_policy is 'recompute':
|
||||||
|
1. Request should remain in RUNNING state
|
||||||
|
2. num_computed_tokens should be truncated to invalid block boundary
|
||||||
|
3. Blocks should NOT be freed (request still needs them for recomputation)
|
||||||
|
4. Request should remain in scheduler.requests and scheduler.running
|
||||||
|
"""
|
||||||
|
num_prompt_blocks = 100
|
||||||
|
num_external_computed_blocks = 99
|
||||||
|
invalid_block_idx = 50
|
||||||
|
|
||||||
|
num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
|
||||||
|
num_external_computed_tokens = (
|
||||||
|
num_external_computed_blocks * recompute_scheduler.block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
request = create_request(num_tokens=num_prompt_tokens)
|
||||||
|
recompute_scheduler.add_request(request=request)
|
||||||
|
|
||||||
|
req_num_new_matched_tokens = {
|
||||||
|
request.request_id: num_external_computed_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
# mock connector indicating sync load
|
||||||
|
recompute_scheduler.connector = Mock()
|
||||||
|
recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
|
||||||
|
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
|
||||||
|
)
|
||||||
|
recompute_scheduler.connector.request_finished.return_value = (False, None)
|
||||||
|
recompute_scheduler.connector.take_events.return_value = ()
|
||||||
|
|
||||||
|
scheduler_output = recompute_scheduler.schedule()
|
||||||
|
|
||||||
|
# request should be running with sync KV load
|
||||||
|
assert len(recompute_scheduler.running) == 1
|
||||||
|
assert len(scheduler_output.scheduled_new_reqs) == 1
|
||||||
|
assert request.status == RequestStatus.RUNNING
|
||||||
|
|
||||||
|
# get the allocated block IDs before invalid blocks are reported
|
||||||
|
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
|
||||||
|
invalid_block_ids = {req_block_ids[invalid_block_idx]}
|
||||||
|
|
||||||
|
# store original num_computed_tokens for comparison
|
||||||
|
original_num_computed_tokens = request.num_computed_tokens
|
||||||
|
|
||||||
|
model_runner_output = create_model_runner_output(
|
||||||
|
[request],
|
||||||
|
invalid_block_ids=invalid_block_ids,
|
||||||
|
use_eos=False, # not finished - should continue running
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = recompute_scheduler.update_from_output(
|
||||||
|
scheduler_output, model_runner_output
|
||||||
|
)
|
||||||
|
|
||||||
|
# critical assertions for recompute case:
|
||||||
|
|
||||||
|
# 1. request should still be RUNNING (not finished, not aborted)
|
||||||
|
assert request.status == RequestStatus.RUNNING, (
|
||||||
|
f"Request should remain RUNNING for recompute, got {request.status}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. num_computed_tokens should be truncated to first invalid block
|
||||||
|
expected_truncated_tokens = invalid_block_idx * recompute_scheduler.block_size
|
||||||
|
assert request.num_computed_tokens == expected_truncated_tokens, (
|
||||||
|
f"num_computed_tokens should be truncated to {expected_truncated_tokens}, "
|
||||||
|
f"got {request.num_computed_tokens}"
|
||||||
|
)
|
||||||
|
assert request.num_computed_tokens < original_num_computed_tokens, (
|
||||||
|
"num_computed_tokens should be reduced after invalid block detection"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. no output should be generated (request is still running)
|
||||||
|
# the request should be skipped in the output loop
|
||||||
|
assert len(outputs) == 0 or request.request_id not in [
|
||||||
|
out.request_id for outs in outputs.values() for out in outs.outputs
|
||||||
|
], "No output should be generated for recompute requests"
|
||||||
|
|
||||||
|
# 4. request should still be in running queue
|
||||||
|
assert request in recompute_scheduler.running, (
|
||||||
|
"Request should remain in running queue for recomputation"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. request should still be in scheduler.requests (not deleted)
|
||||||
|
assert request.request_id in recompute_scheduler.requests, (
|
||||||
|
"Request should not be deleted from scheduler.requests"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. blocks should NOT be freed - verify blocks are still allocated
|
||||||
|
try:
|
||||||
|
allocated_blocks = recompute_scheduler.kv_cache_manager.get_block_ids(
|
||||||
|
request.request_id
|
||||||
|
)
|
||||||
|
assert allocated_blocks is not None
|
||||||
|
assert len(allocated_blocks[0]) > 0, (
|
||||||
|
"Blocks should still be allocated for recomputation"
|
||||||
|
)
|
||||||
|
except KeyError:
|
||||||
|
pytest.fail(
|
||||||
|
"Blocks were freed incorrectly! Running requests need their blocks "
|
||||||
|
"to recompute invalid portions."
|
||||||
|
)
|
||||||
|
|
||||||
|
# 7. verify request can be rescheduled in next step
|
||||||
|
scheduler_output_2 = recompute_scheduler.schedule()
|
||||||
|
|
||||||
|
# request should appear in the new schedule to recompute invalid blocks
|
||||||
|
scheduled_req_ids = [
|
||||||
|
req.request_id for req in scheduler_output_2.scheduled_new_reqs
|
||||||
|
]
|
||||||
|
if scheduler_output_2.num_scheduled_tokens:
|
||||||
|
scheduled_req_ids.extend(scheduler_output_2.num_scheduled_tokens.keys())
|
||||||
|
|
||||||
|
assert (
|
||||||
|
request.request_id in scheduled_req_ids or len(recompute_scheduler.running) > 0
|
||||||
|
), "Request should be reschedulable for recomputation"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler):
|
||||||
|
"""
|
||||||
|
Test sync fail case - invalid blocks must be evicted from cache.
|
||||||
|
|
||||||
|
When a request fails with policy='fail' and has invalid blocks from sync loading:
|
||||||
|
1. Request should be finished with FINISHED_ERROR
|
||||||
|
2. Invalid blocks should be evicted from the KV cache
|
||||||
|
3. Valid blocks (if shared) should remain in cache
|
||||||
|
4. Future requests should not reuse the invalid blocks
|
||||||
|
|
||||||
|
This test verifies that invalid blocks are properly evicted to prevent
|
||||||
|
cache corruption and reuse of invalid data.
|
||||||
|
"""
|
||||||
|
num_prompt_blocks = 100
|
||||||
|
num_external_computed_blocks = 99
|
||||||
|
invalid_block_idx = 50
|
||||||
|
|
||||||
|
num_prompt_tokens = num_prompt_blocks * fail_scheduler.block_size
|
||||||
|
num_external_computed_tokens = (
|
||||||
|
num_external_computed_blocks * fail_scheduler.block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
request = create_request(num_tokens=num_prompt_tokens)
|
||||||
|
fail_scheduler.add_request(request=request)
|
||||||
|
|
||||||
|
req_num_new_matched_tokens = {
|
||||||
|
request.request_id: num_external_computed_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
# mock connector indicating sync load
|
||||||
|
fail_scheduler.connector = Mock()
|
||||||
|
fail_scheduler.connector.get_num_new_matched_tokens.side_effect = (
|
||||||
|
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
|
||||||
|
)
|
||||||
|
fail_scheduler.connector.request_finished.return_value = (False, None)
|
||||||
|
fail_scheduler.connector.take_events.return_value = ()
|
||||||
|
|
||||||
|
scheduler_output = fail_scheduler.schedule()
|
||||||
|
|
||||||
|
# request should be running with sync KV load
|
||||||
|
assert len(fail_scheduler.running) == 1
|
||||||
|
assert request.status == RequestStatus.RUNNING
|
||||||
|
|
||||||
|
# get allocated block IDs
|
||||||
|
req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[0]
|
||||||
|
invalid_block_id = req_block_ids[invalid_block_idx]
|
||||||
|
invalid_block_ids = {invalid_block_id}
|
||||||
|
|
||||||
|
# verify the block is in the block pool before we report it as invalid
|
||||||
|
block = fail_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
|
||||||
|
assert block is not None
|
||||||
|
|
||||||
|
# report invalid blocks - request should fail
|
||||||
|
model_runner_output = create_model_runner_output(
|
||||||
|
[request],
|
||||||
|
invalid_block_ids=invalid_block_ids,
|
||||||
|
use_eos=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = fail_scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
|
|
||||||
|
# verify request is finished with error
|
||||||
|
assert request.status == RequestStatus.FINISHED_ERROR
|
||||||
|
assert request.get_finished_reason() == FinishReason.ERROR
|
||||||
|
|
||||||
|
# verify output is generated
|
||||||
|
assert len(outputs) == 1
|
||||||
|
engine_outputs = next(iter(outputs.values()))
|
||||||
|
assert len(engine_outputs.outputs) == 1
|
||||||
|
output = engine_outputs.outputs[0]
|
||||||
|
assert output.request_id == request.request_id
|
||||||
|
assert output.finish_reason == FinishReason.ERROR
|
||||||
|
|
||||||
|
# verify the request was removed from scheduler
|
||||||
|
assert request.request_id not in fail_scheduler.requests
|
||||||
|
assert len(fail_scheduler.running) == 0
|
||||||
|
|
||||||
|
# critical: verify invalid block was actually freed from cache
|
||||||
|
# this is the key assertion - the invalid block should no longer be
|
||||||
|
# tracked by the KV cache manager for this request
|
||||||
|
# if it's still there, a future request could reuse the invalid data
|
||||||
|
try:
|
||||||
|
block_ids = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id)
|
||||||
|
# if we get here, check if blocks were actually freed
|
||||||
|
if block_ids is not None and len(block_ids[0]) > 0:
|
||||||
|
pytest.fail(
|
||||||
|
f"Invalid blocks still tracked for finished request! "
|
||||||
|
f"Request {request.request_id} should have been freed but "
|
||||||
|
f"still has {len(block_ids[0])} blocks allocated."
|
||||||
|
)
|
||||||
|
# blocks list exists but is empty - this is fine, they were freed
|
||||||
|
except KeyError:
|
||||||
|
# expected - request completely removed from tracking
|
||||||
|
pass
|
||||||
|
|
||||||
|
# critical: verify invalid block was evicted from prefix cache
|
||||||
|
# the block should no longer have a hash (hash is reset on eviction)
|
||||||
|
assert block.block_hash is None, (
|
||||||
|
f"Invalid block {invalid_block_id} should have been evicted from cache "
|
||||||
|
f"(hash should be None), but hash is still {block.block_hash}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_recompute_blocks_not_cached_when_invalid(
|
||||||
|
recompute_scheduler: Scheduler,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test async recompute case - invalid blocks not cached after transfer.
|
||||||
|
|
||||||
|
When async KV loading has invalid blocks and retry_policy is 'recompute':
|
||||||
|
1. Blocks are allocated but not cached yet
|
||||||
|
2. When async transfer completes, only valid blocks should be cached
|
||||||
|
3. Invalid blocks should never enter the prefix cache
|
||||||
|
|
||||||
|
This test verifies correctness, the failed_recving_kv_req_ids protection
|
||||||
|
ensures only valid blocks are cached when the transfer completes, and we
|
||||||
|
only evict blocks from cache that are already hashed in the block table.
|
||||||
|
"""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
num_prompt_blocks = 100
|
||||||
|
num_external_computed_blocks = 99
|
||||||
|
invalid_block_idx = 50
|
||||||
|
|
||||||
|
num_prompt_tokens = num_prompt_blocks * recompute_scheduler.block_size
|
||||||
|
num_external_computed_tokens = (
|
||||||
|
num_external_computed_blocks * recompute_scheduler.block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
request = create_request(num_tokens=num_prompt_tokens)
|
||||||
|
recompute_scheduler.add_request(request=request)
|
||||||
|
|
||||||
|
req_num_new_matched_tokens = {
|
||||||
|
request.request_id: num_external_computed_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
# mock connector indicating async load
|
||||||
|
recompute_scheduler.connector = Mock()
|
||||||
|
recompute_scheduler.connector.get_num_new_matched_tokens.side_effect = (
|
||||||
|
_make_get_num_new_matched_tokens(req_num_new_matched_tokens, True)
|
||||||
|
)
|
||||||
|
recompute_scheduler.connector.request_finished.return_value = (False, None)
|
||||||
|
recompute_scheduler.connector.take_events.return_value = ()
|
||||||
|
|
||||||
|
scheduler_output = recompute_scheduler.schedule()
|
||||||
|
|
||||||
|
# request should be waiting for remote KVs
|
||||||
|
assert len(recompute_scheduler.waiting) == 1
|
||||||
|
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
|
assert request.num_computed_tokens == 0
|
||||||
|
|
||||||
|
# get the allocated block IDs
|
||||||
|
(req_block_ids,) = recompute_scheduler.kv_cache_manager.get_block_ids(
|
||||||
|
request.request_id
|
||||||
|
)
|
||||||
|
invalid_block_id = req_block_ids[invalid_block_idx]
|
||||||
|
invalid_block_ids = {invalid_block_id}
|
||||||
|
|
||||||
|
# get the block object to verify it's not cached yet and stays uncached
|
||||||
|
block = recompute_scheduler.kv_cache_manager.block_pool.blocks[invalid_block_id]
|
||||||
|
|
||||||
|
# verify block has no hash before invalid blocks are reported
|
||||||
|
assert block.block_hash is None, (
|
||||||
|
"Async loading blocks should not be cached yet (no hash)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# report invalid blocks (transfer not finished yet)
|
||||||
|
model_runner_output = create_model_runner_output(
|
||||||
|
reqs=[],
|
||||||
|
finished_recving=None, # transfer NOT finished
|
||||||
|
invalid_block_ids=invalid_block_ids,
|
||||||
|
use_eos=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# critical: spy on evict_blocks to verify it's NOT called for async blocks
|
||||||
|
original_evict_blocks = recompute_scheduler.kv_cache_manager.evict_blocks
|
||||||
|
evict_blocks_calls = []
|
||||||
|
|
||||||
|
def evict_blocks_spy(block_ids):
|
||||||
|
evict_blocks_calls.append(set(block_ids))
|
||||||
|
return original_evict_blocks(block_ids)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
recompute_scheduler.kv_cache_manager, "evict_blocks", evict_blocks_spy
|
||||||
|
):
|
||||||
|
recompute_scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
|
|
||||||
|
# verify evict_blocks was NOT called (async blocks excluded from eviction)
|
||||||
|
assert len(evict_blocks_calls) == 0, (
|
||||||
|
f"evict_blocks should not be called for async-only invalid blocks, "
|
||||||
|
f"but was called {len(evict_blocks_calls)} time(s) with {evict_blocks_calls}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# request should still be waiting (not finished with error due to recompute policy)
|
||||||
|
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
|
assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids
|
||||||
|
|
||||||
|
# verify num_computed_tokens was truncated to before invalid block
|
||||||
|
expected_valid_tokens = invalid_block_idx * recompute_scheduler.block_size
|
||||||
|
assert request.num_computed_tokens == expected_valid_tokens
|
||||||
|
|
||||||
|
# verify invalid block still has no hash (was not evicted)
|
||||||
|
assert block.block_hash is None, (
|
||||||
|
f"Async loading blocks shouldn't be cached or evicted. "
|
||||||
|
f"Block {invalid_block_id} hash should be None but is {block.block_hash}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# now simulate async transfer completing
|
||||||
|
model_runner_output_2 = create_model_runner_output(
|
||||||
|
reqs=[],
|
||||||
|
finished_recving={request.request_id},
|
||||||
|
invalid_block_ids=None,
|
||||||
|
use_eos=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
recompute_scheduler.update_from_output(scheduler_output, model_runner_output_2)
|
||||||
|
|
||||||
|
# verify request is now marked as finished receiving and ready to be processed
|
||||||
|
assert request.request_id in recompute_scheduler.finished_recving_kv_req_ids
|
||||||
|
assert request.request_id in recompute_scheduler.failed_recving_kv_req_ids
|
||||||
|
|
||||||
|
# critical: verify invalid block still has no hash before recompute
|
||||||
|
# the async transfer invalid data was never cached
|
||||||
|
assert block.block_hash is None, (
|
||||||
|
f"Invalid block {invalid_block_id} should not be cached before recompute "
|
||||||
|
f"(hash should be None), but hash is {block.block_hash}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# critical end-to-end test: spy on cache_blocks to verify it's called with
|
||||||
|
# the truncated num_computed_tokens value
|
||||||
|
original_cache_blocks = recompute_scheduler.kv_cache_manager.cache_blocks
|
||||||
|
cache_blocks_calls = []
|
||||||
|
|
||||||
|
def cache_blocks_spy(req, num_tokens):
|
||||||
|
cache_blocks_calls.append((req.request_id, num_tokens))
|
||||||
|
return original_cache_blocks(req, num_tokens)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
recompute_scheduler.kv_cache_manager, "cache_blocks", cache_blocks_spy
|
||||||
|
):
|
||||||
|
# call schedule() again - this triggers _update_waiting_for_remote_kv()
|
||||||
|
# which should call cache_blocks with the truncated value
|
||||||
|
recompute_scheduler.schedule()
|
||||||
|
|
||||||
|
# verify cache_blocks was called with the truncated value
|
||||||
|
assert len(cache_blocks_calls) == 1, (
|
||||||
|
f"cache_blocks should be called exactly once, "
|
||||||
|
f"got {len(cache_blocks_calls)} calls"
|
||||||
|
)
|
||||||
|
cached_req_id, cached_num_tokens = cache_blocks_calls[0]
|
||||||
|
assert cached_req_id == request.request_id
|
||||||
|
assert cached_num_tokens == expected_valid_tokens, (
|
||||||
|
f"cache_blocks should be called with truncated value {expected_valid_tokens}, "
|
||||||
|
f"but was called with {cached_num_tokens}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# request should now be RUNNING (scheduled immediately after transfer completes)
|
||||||
|
# the flow is: WAITING_FOR_REMOTE_KVS -> WAITING -> RUNNING in same schedule() call
|
||||||
|
assert request.status == RequestStatus.RUNNING
|
||||||
|
|
||||||
|
# num_computed_tokens should be >= expected_valid_tokens because the scheduler
|
||||||
|
# will schedule additional new tokens (up to max_num_batched_tokens) for the request
|
||||||
|
assert request.num_computed_tokens >= expected_valid_tokens, (
|
||||||
|
f"num_computed_tokens should be at least {expected_valid_tokens}, "
|
||||||
|
f"got {request.num_computed_tokens}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# request should no longer be in the failed/finished receiving sets
|
||||||
|
assert request.request_id not in recompute_scheduler.failed_recving_kv_req_ids
|
||||||
|
assert request.request_id not in recompute_scheduler.finished_recving_kv_req_ids
|
||||||
|
|
||||||
|
# request should be in the running queue
|
||||||
|
assert request in recompute_scheduler.running
|
||||||
@ -88,8 +88,8 @@ def forward_attention(
|
|||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
query_start_loc_cpu=query_start_loc.cpu(),
|
query_start_loc_cpu=query_start_loc.cpu(),
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
seq_lens_cpu=seq_lens.cpu(),
|
_seq_lens_cpu=seq_lens.cpu(),
|
||||||
num_computed_tokens_cpu=context_lens.cpu(),
|
_num_computed_tokens_cpu=context_lens.cpu(),
|
||||||
num_reqs=batch_size,
|
num_reqs=batch_size,
|
||||||
num_actual_tokens=num_actual_tokens,
|
num_actual_tokens=num_actual_tokens,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
|
|||||||
@ -7,7 +7,7 @@ Here we break down the requirements in 2 steps:
|
|||||||
1. Build and install the Python libraries (both [pplx-kernels](https://github.com/ppl-ai/pplx-kernels) and [DeepEP](https://github.com/deepseek-ai/DeepEP)), including necessary dependencies like NVSHMEM. This step does not require any privileged access. Any user can do this.
|
1. Build and install the Python libraries (both [pplx-kernels](https://github.com/ppl-ai/pplx-kernels) and [DeepEP](https://github.com/deepseek-ai/DeepEP)), including necessary dependencies like NVSHMEM. This step does not require any privileged access. Any user can do this.
|
||||||
2. Configure NVIDIA driver to enable IBGDA. This step requires root access, and must be done on the host machine.
|
2. Configure NVIDIA driver to enable IBGDA. This step requires root access, and must be done on the host machine.
|
||||||
|
|
||||||
2 is necessary for multi-node deployment.
|
Step 2 is necessary for multi-node deployment.
|
||||||
|
|
||||||
All scripts accept a positional argument as workspace path for staging the build, defaulting to `$(pwd)/ep_kernels_workspace`.
|
All scripts accept a positional argument as workspace path for staging the build, defaulting to `$(pwd)/ep_kernels_workspace`.
|
||||||
|
|
||||||
@ -23,6 +23,6 @@ TORCH_CUDA_ARCH_LIST="10.0" bash install_python_libraries.sh
|
|||||||
Additional step for multi-node deployment:
|
Additional step for multi-node deployment:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
sudo bash configure_system_drivers.sh
|
sudo bash configure_system_drivers.sh # update-initramfs can take several minutes
|
||||||
sudo reboot # Reboot is required to load the new driver
|
sudo reboot # Reboot is required to load the new driver
|
||||||
```
|
```
|
||||||
|
|||||||
@ -24,6 +24,15 @@ def is_aiter_found() -> bool:
|
|||||||
# we keep this global outside to not cause torch compile breaks.
|
# we keep this global outside to not cause torch compile breaks.
|
||||||
IS_AITER_FOUND = is_aiter_found()
|
IS_AITER_FOUND = is_aiter_found()
|
||||||
|
|
||||||
|
# Can't use dtypes.fp8 directly inside an op
|
||||||
|
# because it returns wrong result on gfx942.
|
||||||
|
# This is a workaround to get the correct FP8 dtype.
|
||||||
|
# This might because that the get_gfx() is wrapped as a custom op.
|
||||||
|
if IS_AITER_FOUND:
|
||||||
|
from aiter import dtypes
|
||||||
|
|
||||||
|
AITER_FP8_DTYPE = dtypes.fp8
|
||||||
|
|
||||||
|
|
||||||
def if_aiter_supported(func: Callable) -> Callable:
|
def if_aiter_supported(func: Callable) -> Callable:
|
||||||
"""Decorator that only executes the function if
|
"""Decorator that only executes the function if
|
||||||
@ -45,36 +54,6 @@ def if_aiter_supported(func: Callable) -> Callable:
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def _rocm_aiter_group_fp8_quant_impl(
|
|
||||||
x: torch.Tensor,
|
|
||||||
group_size: int,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
|
|
||||||
from aiter import QuantType, dtypes, get_hip_quant
|
|
||||||
|
|
||||||
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
|
|
||||||
return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8)
|
|
||||||
|
|
||||||
|
|
||||||
def _rocm_aiter_group_fp8_quant_fake(
|
|
||||||
x: torch.Tensor,
|
|
||||||
group_size: int,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
from aiter import dtypes
|
|
||||||
|
|
||||||
M, N = x.shape
|
|
||||||
x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device)
|
|
||||||
out_bs = torch.empty(
|
|
||||||
(
|
|
||||||
M,
|
|
||||||
(N + group_size - 1) // group_size,
|
|
||||||
),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=x.device,
|
|
||||||
)
|
|
||||||
return x_fp8, out_bs
|
|
||||||
|
|
||||||
|
|
||||||
def _rocm_aiter_fused_moe_impl(
|
def _rocm_aiter_fused_moe_impl(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
@ -522,6 +501,142 @@ def _rocm_aiter_per_token_quant_fake(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
variance_epsilon: float,
|
||||||
|
group_size: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
|
||||||
|
|
||||||
|
(x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
|
||||||
|
x,
|
||||||
|
weight,
|
||||||
|
variance_epsilon,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
group_size=group_size,
|
||||||
|
dtype_quant=AITER_FP8_DTYPE,
|
||||||
|
res1=residual,
|
||||||
|
)
|
||||||
|
return (x_quant, x_quant_scales, res)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
variance_epsilon: float,
|
||||||
|
group_size: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
M, N = x.shape
|
||||||
|
scale_shape = (M, (N + group_size - 1) // group_size)
|
||||||
|
return (
|
||||||
|
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
|
||||||
|
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
|
||||||
|
torch.empty_like(residual, device=residual.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rmsnorm_fp8_group_quant_impl(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
variance_epsilon: float,
|
||||||
|
group_size: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant
|
||||||
|
|
||||||
|
(x_quant, x_quant_scales), _, _, res = fused_rms_fp8_group_quant(
|
||||||
|
x,
|
||||||
|
weight,
|
||||||
|
variance_epsilon,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
group_size=group_size,
|
||||||
|
dtype_quant=AITER_FP8_DTYPE,
|
||||||
|
res1=None,
|
||||||
|
)
|
||||||
|
return (x_quant, x_quant_scales)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_rmsnorm_fp8_group_quant_fake(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
variance_epsilon: float,
|
||||||
|
group_size: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
M, N = x.shape
|
||||||
|
scale_shape = (M, (N + group_size - 1) // group_size)
|
||||||
|
return (
|
||||||
|
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
|
||||||
|
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_group_fp8_quant_impl(
|
||||||
|
x: torch.Tensor,
|
||||||
|
group_size: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
|
||||||
|
from aiter import QuantType, get_hip_quant
|
||||||
|
|
||||||
|
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
|
||||||
|
return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_group_fp8_quant_fake(
|
||||||
|
x: torch.Tensor,
|
||||||
|
group_size: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
M, N = x.shape
|
||||||
|
x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device)
|
||||||
|
out_bs = torch.empty(
|
||||||
|
(
|
||||||
|
M,
|
||||||
|
(N + group_size - 1) // group_size,
|
||||||
|
),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=x.device,
|
||||||
|
)
|
||||||
|
return x_fp8, out_bs
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
|
||||||
|
x: torch.Tensor,
|
||||||
|
group_size: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
from aiter.ops.triton.activation import act_mul_and_fp8_group_quant
|
||||||
|
|
||||||
|
return act_mul_and_fp8_group_quant(
|
||||||
|
x,
|
||||||
|
activation="silu",
|
||||||
|
group_size=group_size,
|
||||||
|
dtype_quant=AITER_FP8_DTYPE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
|
||||||
|
x: torch.Tensor,
|
||||||
|
group_size: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
M, N = x.shape
|
||||||
|
assert N % 2 == 0
|
||||||
|
N_half = N // 2
|
||||||
|
x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device)
|
||||||
|
out_bs = torch.empty(
|
||||||
|
(
|
||||||
|
M,
|
||||||
|
(N_half + group_size - 1) // group_size,
|
||||||
|
),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=x.device,
|
||||||
|
)
|
||||||
|
return x_fp8, out_bs
|
||||||
|
|
||||||
|
|
||||||
# Global flag to ensure ops are registered only once
|
# Global flag to ensure ops are registered only once
|
||||||
_OPS_REGISTERED = False
|
_OPS_REGISTERED = False
|
||||||
|
|
||||||
@ -557,7 +672,7 @@ class rocm_aiter_ops:
|
|||||||
@if_aiter_supported
|
@if_aiter_supported
|
||||||
def is_linear_fp8_enaled(cls) -> bool:
|
def is_linear_fp8_enaled(cls) -> bool:
|
||||||
""" "Verifies device specs and availability of env variable."""
|
""" "Verifies device specs and availability of env variable."""
|
||||||
return cls.is_linear_enabled() and current_platform.is_fp8_fnuz()
|
return cls.is_linear_enabled()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@if_aiter_supported
|
@if_aiter_supported
|
||||||
@ -632,14 +747,6 @@ class rocm_aiter_ops:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# register all the custom ops here
|
# register all the custom ops here
|
||||||
direct_register_custom_op(
|
|
||||||
op_name="rocm_aiter_group_fp8_quant",
|
|
||||||
op_func=_rocm_aiter_group_fp8_quant_impl,
|
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=_rocm_aiter_group_fp8_quant_fake,
|
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_asm_moe_tkw1",
|
op_name="rocm_aiter_asm_moe_tkw1",
|
||||||
op_func=_rocm_aiter_asm_moe_tkw1_impl,
|
op_func=_rocm_aiter_asm_moe_tkw1_impl,
|
||||||
@ -699,27 +806,46 @@ class rocm_aiter_ops:
|
|||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_gemm_a8w8_blockscale",
|
op_name="rocm_aiter_gemm_a8w8_blockscale",
|
||||||
op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
|
op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
|
fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_rms_norm",
|
op_name="rocm_aiter_rms_norm",
|
||||||
op_func=_rocm_aiter_rms_norm_impl,
|
op_func=_rocm_aiter_rms_norm_impl,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=_rocm_aiter_rms_norm_fake,
|
fake_impl=_rocm_aiter_rms_norm_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
|
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
|
||||||
op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
|
op_func=_rocm_aiter_rmsnorm2d_fwd_with_add_impl,
|
||||||
mutates_args=[],
|
|
||||||
fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
|
fake_impl=_rocm_aiter_rmsnorm2d_fwd_with_add_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
dispatch_key=current_platform.dispatch_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_rmsnorm_fp8_group_quant",
|
||||||
|
op_func=_rocm_aiter_rmsnorm_fp8_group_quant_impl,
|
||||||
|
fake_impl=_rocm_aiter_rmsnorm_fp8_group_quant_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_rmsnorm_with_add_fp8_group_quant",
|
||||||
|
op_func=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl,
|
||||||
|
fake_impl=_rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_act_mul_and_fp8_group_quant",
|
||||||
|
op_func=_rocm_aiter_act_mul_and_fp8_group_quant_impl,
|
||||||
|
fake_impl=_rocm_aiter_act_mul_and_fp8_group_quant_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_group_fp8_quant",
|
||||||
|
op_func=_rocm_aiter_group_fp8_quant_impl,
|
||||||
|
fake_impl=_rocm_aiter_group_fp8_quant_fake,
|
||||||
|
)
|
||||||
|
|
||||||
direct_register_custom_op(
|
direct_register_custom_op(
|
||||||
op_name="rocm_aiter_per_tensor_quant",
|
op_name="rocm_aiter_per_tensor_quant",
|
||||||
op_func=_rocm_aiter_per_tensor_quant_impl,
|
op_func=_rocm_aiter_per_tensor_quant_impl,
|
||||||
|
|||||||
@ -294,6 +294,12 @@ class AttentionImpl(ABC, Generic[T]):
|
|||||||
# Some features like decode context parallelism require the softmax lse.
|
# Some features like decode context parallelism require the softmax lse.
|
||||||
can_return_lse_for_decode: bool = False
|
can_return_lse_for_decode: bool = False
|
||||||
|
|
||||||
|
# Whether the attention impl supports Prefill Context Parallelism.
|
||||||
|
supports_pcp: bool = False
|
||||||
|
# Whether the attention impl(or ops) supports MTP
|
||||||
|
# when cp_kv_cache_interleave_size > 1
|
||||||
|
supports_mtp_with_cp_non_trivial_interleave_size: bool = False
|
||||||
|
|
||||||
# some attention backends might not always want to return lse
|
# some attention backends might not always want to return lse
|
||||||
# even if they can return lse (for efficiency reasons)
|
# even if they can return lse (for efficiency reasons)
|
||||||
need_to_return_lse_for_decode: bool = False
|
need_to_return_lse_for_decode: bool = False
|
||||||
|
|||||||
@ -252,35 +252,3 @@ def register_backend(
|
|||||||
return lambda x: x
|
return lambda x: x
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
# Backwards compatibility alias for plugins
|
|
||||||
class _BackendMeta(type):
|
|
||||||
"""Metaclass to provide deprecation warnings when accessing _Backend."""
|
|
||||||
|
|
||||||
def __getattribute__(cls, name: str):
|
|
||||||
if name not in ("__class__", "__mro__", "__name__"):
|
|
||||||
logger.warning(
|
|
||||||
"_Backend has been renamed to AttentionBackendEnum. "
|
|
||||||
"Please update your code to use AttentionBackendEnum instead. "
|
|
||||||
"_Backend will be removed in a future release."
|
|
||||||
)
|
|
||||||
return getattr(AttentionBackendEnum, name)
|
|
||||||
|
|
||||||
def __getitem__(cls, name: str):
|
|
||||||
logger.warning(
|
|
||||||
"_Backend has been renamed to AttentionBackendEnum. "
|
|
||||||
"Please update your code to use AttentionBackendEnum instead. "
|
|
||||||
"_Backend will be removed in a future release."
|
|
||||||
)
|
|
||||||
return AttentionBackendEnum[name]
|
|
||||||
|
|
||||||
|
|
||||||
class _Backend(metaclass=_BackendMeta):
|
|
||||||
"""Deprecated: Use AttentionBackendEnum instead.
|
|
||||||
|
|
||||||
This class is provided for backwards compatibility with plugins
|
|
||||||
and will be removed in a future release.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|||||||
@ -103,7 +103,7 @@ def create_cross_attention_backend(
|
|||||||
# needed here to know how many tokens to attend to from the cached
|
# needed here to know how many tokens to attend to from the cached
|
||||||
# cross-attention KV cache.
|
# cross-attention KV cache.
|
||||||
new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
|
new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
|
||||||
new_metadata.seq_lens_cpu = torch.from_numpy(
|
new_metadata._seq_lens_cpu = torch.from_numpy(
|
||||||
common_attn_metadata.encoder_seq_lens_cpu
|
common_attn_metadata.encoder_seq_lens_cpu
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import inspect
|
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from typing import cast, get_args
|
from typing import cast, get_args
|
||||||
|
|
||||||
@ -73,39 +72,18 @@ def _cached_get_attn_backend(
|
|||||||
) -> type[AttentionBackend]:
|
) -> type[AttentionBackend]:
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
sig = inspect.signature(current_platform.get_attn_backend_cls)
|
attention_cls = current_platform.get_attn_backend_cls(
|
||||||
if "use_v1" in sig.parameters:
|
backend,
|
||||||
logger.warning_once(
|
head_size,
|
||||||
"use_v1 parameter for get_attn_backend_cls is deprecated and will "
|
dtype,
|
||||||
"be removed in v0.13.0 or v1.0.0, whichever is soonest. Please "
|
kv_cache_dtype,
|
||||||
"remove it from your plugin code."
|
block_size,
|
||||||
)
|
use_mla,
|
||||||
attention_cls = current_platform.get_attn_backend_cls(
|
has_sink,
|
||||||
backend,
|
use_sparse,
|
||||||
head_size,
|
use_mm_prefix,
|
||||||
dtype,
|
attn_type,
|
||||||
kv_cache_dtype,
|
)
|
||||||
block_size,
|
|
||||||
True, # use_v1
|
|
||||||
use_mla,
|
|
||||||
has_sink,
|
|
||||||
use_sparse,
|
|
||||||
use_mm_prefix,
|
|
||||||
attn_type,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
attention_cls = current_platform.get_attn_backend_cls(
|
|
||||||
backend,
|
|
||||||
head_size,
|
|
||||||
dtype,
|
|
||||||
kv_cache_dtype,
|
|
||||||
block_size,
|
|
||||||
use_mla,
|
|
||||||
has_sink,
|
|
||||||
use_sparse,
|
|
||||||
use_mm_prefix,
|
|
||||||
attn_type,
|
|
||||||
)
|
|
||||||
if not attention_cls:
|
if not attention_cls:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid attention backend for {current_platform.device_name}"
|
f"Invalid attention backend for {current_platform.device_name}"
|
||||||
|
|||||||
@ -788,7 +788,7 @@ async def benchmark(
|
|||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
"{:<40} {:<10.2f}".format(
|
"{:<40} {:<10.2f}".format(
|
||||||
"Total Token throughput (tok/s):", metrics.total_token_throughput
|
"Total token throughput (tok/s):", metrics.total_token_throughput
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import functools
|
|||||||
from torch import fx as fx
|
from torch import fx as fx
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.config import VllmConfig, set_current_vllm_config
|
from vllm.config import VllmConfig, set_current_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -13,6 +14,12 @@ from vllm.utils.system_utils import set_env_var
|
|||||||
from .post_cleanup import PostCleanupPass
|
from .post_cleanup import PostCleanupPass
|
||||||
from .vllm_inductor_pass import VllmInductorPass
|
from .vllm_inductor_pass import VllmInductorPass
|
||||||
|
|
||||||
|
if rocm_aiter_ops.is_enabled():
|
||||||
|
from vllm.compilation.rocm_aiter_fusion import (
|
||||||
|
RocmAiterRMSNormFp8GroupQuantFusionPass,
|
||||||
|
RocmAiterSiluMulFp8GroupQuantFusionPass,
|
||||||
|
)
|
||||||
|
|
||||||
if current_platform.is_cuda_alike():
|
if current_platform.is_cuda_alike():
|
||||||
from .activation_quant_fusion import ActivationQuantFusionPass
|
from .activation_quant_fusion import ActivationQuantFusionPass
|
||||||
from .fusion import RMSNormQuantFusionPass
|
from .fusion import RMSNormQuantFusionPass
|
||||||
@ -109,8 +116,12 @@ class PostGradPassManager(CustomGraphPass):
|
|||||||
|
|
||||||
if self.pass_config.fuse_norm_quant:
|
if self.pass_config.fuse_norm_quant:
|
||||||
self.passes += [RMSNormQuantFusionPass(config)]
|
self.passes += [RMSNormQuantFusionPass(config)]
|
||||||
|
if rocm_aiter_ops.is_enabled():
|
||||||
|
self.passes += [RocmAiterRMSNormFp8GroupQuantFusionPass(config)]
|
||||||
if self.pass_config.fuse_act_quant:
|
if self.pass_config.fuse_act_quant:
|
||||||
self.passes += [ActivationQuantFusionPass(config)]
|
self.passes += [ActivationQuantFusionPass(config)]
|
||||||
|
if rocm_aiter_ops.is_enabled():
|
||||||
|
self.passes += [RocmAiterSiluMulFp8GroupQuantFusionPass(config)]
|
||||||
|
|
||||||
if self.pass_config.fuse_attn_quant:
|
if self.pass_config.fuse_attn_quant:
|
||||||
self.passes += [AttnFusionPass(config)]
|
self.passes += [AttnFusionPass(config)]
|
||||||
|
|||||||
242
vllm/compilation/rocm_aiter_fusion.py
Normal file
242
vllm/compilation/rocm_aiter_fusion.py
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch._inductor.pattern_matcher as pm
|
||||||
|
from torch import fx
|
||||||
|
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||||
|
from torch._ops import OpOverload
|
||||||
|
|
||||||
|
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
|
||||||
|
from vllm.compilation.activation_quant_fusion import ActivationQuantPattern
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
from .fusion import empty_bf16
|
||||||
|
from .inductor_pass import enable_fake_mode
|
||||||
|
from .matcher_utils import MatcherSiluAndMul
|
||||||
|
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
|
||||||
|
AITER_RMS_GROUP_QUANT_OP = torch.ops.vllm.rocm_aiter_rmsnorm_fp8_group_quant.default
|
||||||
|
AITER_RMS_ADD_GROUP_QUANT_OP = (
|
||||||
|
torch.ops.vllm.rocm_aiter_rmsnorm_with_add_fp8_group_quant.default
|
||||||
|
)
|
||||||
|
|
||||||
|
AITER_RMS_OP = torch.ops.vllm.rocm_aiter_rms_norm.default
|
||||||
|
AITER_RMS_ADD_OP = torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add.default
|
||||||
|
|
||||||
|
AITER_GROUP_FP8_QUANT_OP = torch.ops.vllm.rocm_aiter_group_fp8_quant.default
|
||||||
|
TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
|
||||||
|
|
||||||
|
FUSED_SILU_MUL_QUANT_OP = torch.ops.vllm.rocm_aiter_act_mul_and_fp8_group_quant.default
|
||||||
|
|
||||||
|
|
||||||
|
class AiterRMSFp8GroupQuantPattern:
|
||||||
|
"""
|
||||||
|
This pattern fuses aiter rms_norm & group fp8 quant custom
|
||||||
|
ops into an aiter rms_norm_group_fp8_quant op.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.quant_dtype = quant_dtype
|
||||||
|
self.quant_op = quant_op
|
||||||
|
|
||||||
|
def register(self, pm_pass: PatternMatcherPass):
|
||||||
|
def pattern(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
):
|
||||||
|
at1 = AITER_RMS_OP(x=input, weight=weight, variance_epsilon=self.epsilon)
|
||||||
|
|
||||||
|
at2 = self.quant_op(at1, 128)
|
||||||
|
|
||||||
|
return at2[0], at2[1]
|
||||||
|
|
||||||
|
def replacement(
|
||||||
|
input: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
):
|
||||||
|
at = AITER_RMS_GROUP_QUANT_OP(
|
||||||
|
x=input,
|
||||||
|
weight=weight,
|
||||||
|
variance_epsilon=self.epsilon,
|
||||||
|
group_size=128,
|
||||||
|
)
|
||||||
|
|
||||||
|
return at[0], at[1]
|
||||||
|
|
||||||
|
inputs = [
|
||||||
|
empty_bf16(5, 4), # input
|
||||||
|
empty_bf16(1, 5), # weight
|
||||||
|
]
|
||||||
|
|
||||||
|
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||||
|
|
||||||
|
|
||||||
|
class AiterFusedAddRMSFp8GroupQuantPattern:
|
||||||
|
"""
|
||||||
|
This pattern fuses aiter rms_norm_with_add & group fp8 quant custom ops
|
||||||
|
into a aiter rms_norm_with_add_group_fp8_quant op.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, epsilon: float, quant_dtype: torch.dtype, quant_op: OpOverload):
|
||||||
|
self.epsilon = epsilon
|
||||||
|
self.quant_dtype = quant_dtype
|
||||||
|
self.quant_op = quant_op
|
||||||
|
|
||||||
|
def register(self, pm_pass: PatternMatcherPass):
|
||||||
|
def pattern(
|
||||||
|
input: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
):
|
||||||
|
at1 = AITER_RMS_ADD_OP(
|
||||||
|
x=input,
|
||||||
|
residual=residual,
|
||||||
|
weight=weight,
|
||||||
|
variance_epsilon=self.epsilon,
|
||||||
|
)
|
||||||
|
|
||||||
|
at2 = self.quant_op(at1[0], 128)
|
||||||
|
|
||||||
|
# result, scale, residual
|
||||||
|
return at2[0], at2[1], at1[1]
|
||||||
|
|
||||||
|
def replacement(
|
||||||
|
input: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
):
|
||||||
|
at = AITER_RMS_ADD_GROUP_QUANT_OP(
|
||||||
|
x=input,
|
||||||
|
residual=residual,
|
||||||
|
weight=weight,
|
||||||
|
variance_epsilon=self.epsilon,
|
||||||
|
group_size=128,
|
||||||
|
)
|
||||||
|
|
||||||
|
# result, scale, residual
|
||||||
|
return at[0], at[1], at[2]
|
||||||
|
|
||||||
|
inputs = [
|
||||||
|
empty_bf16(5, 4), # input
|
||||||
|
empty_bf16(5, 4), # residual
|
||||||
|
empty_bf16(1, 5), # weight
|
||||||
|
]
|
||||||
|
|
||||||
|
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||||
|
|
||||||
|
|
||||||
|
class RocmAiterRMSNormFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||||
|
"""
|
||||||
|
This pass fuses rms_norm & quant custom ops into a fused rms_norm_quant op.
|
||||||
|
It also supports fused_add_rms_norm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@enable_fake_mode
|
||||||
|
def __init__(self, config: VllmConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||||
|
pass_name="rocm_aiter_rms_norm_fp8_group_quant_fusion_pass"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure fused add patterns are before simple rms norm,
|
||||||
|
# as the latter is a subset of the former in torch ops
|
||||||
|
for epsilon in [1e-5, 1e-6]:
|
||||||
|
# Fuse rms_norm + dynamic group fp8 quant
|
||||||
|
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
|
||||||
|
AiterRMSFp8GroupQuantPattern(epsilon, FP8_DTYPE, quant_op).register(
|
||||||
|
self.patterns
|
||||||
|
)
|
||||||
|
|
||||||
|
AiterFusedAddRMSFp8GroupQuantPattern(
|
||||||
|
epsilon, FP8_DTYPE, quant_op
|
||||||
|
).register(self.patterns)
|
||||||
|
|
||||||
|
self.dump_patterns(config, self.patterns)
|
||||||
|
|
||||||
|
@VllmInductorPass.time_and_log
|
||||||
|
def __call__(self, graph: fx.Graph):
|
||||||
|
self.matched_count = self.patterns.apply(graph)
|
||||||
|
logger.debug("Replaced %s patterns", self.matched_count)
|
||||||
|
|
||||||
|
def uuid(self) -> Any:
|
||||||
|
fusion_patterns = [
|
||||||
|
AiterRMSFp8GroupQuantPattern,
|
||||||
|
AiterFusedAddRMSFp8GroupQuantPattern,
|
||||||
|
]
|
||||||
|
return self.hash_source(self, *fusion_patterns)
|
||||||
|
|
||||||
|
|
||||||
|
class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
|
||||||
|
"""
|
||||||
|
This pattern fuses aiter silu_and_mul & group fp8 quant custom
|
||||||
|
ops into an aiter silu_and_mul_group_fp8_quant op.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, quant_op: OpOverload):
|
||||||
|
self.silu_and_mul_matcher = MatcherSiluAndMul()
|
||||||
|
self.quant_op = quant_op
|
||||||
|
|
||||||
|
def register(self, pm_pass: PatternMatcherPass):
|
||||||
|
def pattern(
|
||||||
|
input: torch.Tensor,
|
||||||
|
):
|
||||||
|
at1 = self.silu_and_mul_matcher(input)
|
||||||
|
at2 = self.quant_op(at1, 128)
|
||||||
|
return at2[0], at2[1]
|
||||||
|
|
||||||
|
def replacement(
|
||||||
|
input: torch.Tensor,
|
||||||
|
):
|
||||||
|
at = FUSED_SILU_MUL_QUANT_OP(x=input, group_size=128)
|
||||||
|
return at[0], at[1]
|
||||||
|
|
||||||
|
inputs = [
|
||||||
|
self.silu_and_mul_matcher.inputs()[0],
|
||||||
|
]
|
||||||
|
|
||||||
|
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
|
||||||
|
|
||||||
|
|
||||||
|
class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
|
||||||
|
"""
|
||||||
|
This pass fuses a pre-defined set of custom ops into fused ops.
|
||||||
|
It uses the torch pattern matcher to find the patterns and replace them.
|
||||||
|
|
||||||
|
Because patterns can only be registered once, the pass is a singleton.
|
||||||
|
This will be addressed in a future version of PyTorch:
|
||||||
|
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
|
||||||
|
"""
|
||||||
|
|
||||||
|
@enable_fake_mode
|
||||||
|
def __init__(self, config: VllmConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
self.patterns: PatternMatcherPass = PatternMatcherPass(
|
||||||
|
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
|
||||||
|
)
|
||||||
|
|
||||||
|
for quant_op in [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]:
|
||||||
|
AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
|
||||||
|
|
||||||
|
self.dump_patterns(config, self.patterns)
|
||||||
|
|
||||||
|
@VllmInductorPass.time_and_log
|
||||||
|
def __call__(self, graph: torch.fx.Graph):
|
||||||
|
self.matched_count = self.patterns.apply(graph)
|
||||||
|
logger.debug("Replaced %s patterns", self.matched_count)
|
||||||
|
|
||||||
|
def uuid(self):
|
||||||
|
fusion_patterns = [
|
||||||
|
ActivationQuantPattern,
|
||||||
|
AiterSiluMulFp8GroupQuantPattern,
|
||||||
|
]
|
||||||
|
return VllmInductorPass.hash_source(self, *fusion_patterns)
|
||||||
@ -17,7 +17,6 @@ from vllm.config.utils import (
|
|||||||
Range,
|
Range,
|
||||||
config,
|
config,
|
||||||
get_hash_factors,
|
get_hash_factors,
|
||||||
handle_deprecated,
|
|
||||||
hash_factors,
|
hash_factors,
|
||||||
)
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -127,27 +126,6 @@ class PassConfig:
|
|||||||
fuse_allreduce_rms: bool = Field(default=None)
|
fuse_allreduce_rms: bool = Field(default=None)
|
||||||
"""Enable flashinfer allreduce fusion."""
|
"""Enable flashinfer allreduce fusion."""
|
||||||
|
|
||||||
# Deprecated flags
|
|
||||||
enable_fusion: bool = Field(default=None)
|
|
||||||
"""Deprecated in: v0.12.0. Use fuse_norm_quant and fuse_act_quant
|
|
||||||
instead. Will be removed in v0.13.0 or v1.0.0, whichever is sooner.
|
|
||||||
"""
|
|
||||||
enable_attn_fusion: bool = Field(default=None)
|
|
||||||
"""Deprecated in: v0.12.0. Use fuse_attn_quant instead.
|
|
||||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
|
||||||
enable_noop: bool = Field(default=None)
|
|
||||||
"""Deprecated in: v0.12.0. Use eliminate_noops instead.
|
|
||||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
|
||||||
enable_sequence_parallelism: bool = Field(default=None)
|
|
||||||
"""Deprecated in: v0.12.0. Use enable_sp instead.
|
|
||||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
|
||||||
enable_async_tp: bool = Field(default=None)
|
|
||||||
"""Deprecated in: v0.12.0. Use fuse_gemm_comms instead.
|
|
||||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
|
||||||
enable_fi_allreduce_fusion: bool = Field(default=None)
|
|
||||||
"""Deprecated in: v0.12.0. Use fuse_allreduce_rms instead.
|
|
||||||
Will be removed in v0.13.0 or v1.0.0, whichever is sooner."""
|
|
||||||
|
|
||||||
fi_allreduce_fusion_max_size_mb: float | None = None
|
fi_allreduce_fusion_max_size_mb: float | None = None
|
||||||
"""The threshold of the communicated tensor sizes under which
|
"""The threshold of the communicated tensor sizes under which
|
||||||
vllm should use flashinfer fused allreduce. Specified as a
|
vllm should use flashinfer fused allreduce. Specified as a
|
||||||
@ -206,15 +184,7 @@ class PassConfig:
|
|||||||
Any future fields that don't affect compilation should be excluded.
|
Any future fields that don't affect compilation should be excluded.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ignored_fields = [
|
return hash_factors(get_hash_factors(self, set()))
|
||||||
"enable_fusion",
|
|
||||||
"enable_attn_fusion",
|
|
||||||
"enable_noop",
|
|
||||||
"enable_sequence_parallelism",
|
|
||||||
"enable_async_tp",
|
|
||||||
"enable_fi_allreduce_fusion",
|
|
||||||
]
|
|
||||||
return hash_factors(get_hash_factors(self, ignored_factors=ignored_fields))
|
|
||||||
|
|
||||||
@field_validator(
|
@field_validator(
|
||||||
"fuse_norm_quant",
|
"fuse_norm_quant",
|
||||||
@ -224,12 +194,6 @@ class PassConfig:
|
|||||||
"enable_sp",
|
"enable_sp",
|
||||||
"fuse_gemm_comms",
|
"fuse_gemm_comms",
|
||||||
"fuse_allreduce_rms",
|
"fuse_allreduce_rms",
|
||||||
"enable_fusion",
|
|
||||||
"enable_attn_fusion",
|
|
||||||
"enable_noop",
|
|
||||||
"enable_sequence_parallelism",
|
|
||||||
"enable_async_tp",
|
|
||||||
"enable_fi_allreduce_fusion",
|
|
||||||
mode="wrap",
|
mode="wrap",
|
||||||
)
|
)
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -242,49 +206,6 @@ class PassConfig:
|
|||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# Handle deprecation and defaults
|
# Handle deprecation and defaults
|
||||||
|
|
||||||
# Map old flags to new flags and issue warnings
|
|
||||||
handle_deprecated(
|
|
||||||
self,
|
|
||||||
"enable_fusion",
|
|
||||||
["fuse_norm_quant", "fuse_act_quant"],
|
|
||||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
|
||||||
)
|
|
||||||
|
|
||||||
handle_deprecated(
|
|
||||||
self,
|
|
||||||
"enable_attn_fusion",
|
|
||||||
"fuse_attn_quant",
|
|
||||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
|
||||||
)
|
|
||||||
|
|
||||||
handle_deprecated(
|
|
||||||
self,
|
|
||||||
"enable_sequence_parallelism",
|
|
||||||
"enable_sp",
|
|
||||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
|
||||||
)
|
|
||||||
|
|
||||||
handle_deprecated(
|
|
||||||
self,
|
|
||||||
"enable_async_tp",
|
|
||||||
"fuse_gemm_comms",
|
|
||||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
|
||||||
)
|
|
||||||
|
|
||||||
handle_deprecated(
|
|
||||||
self,
|
|
||||||
"enable_fi_allreduce_fusion",
|
|
||||||
"fuse_allreduce_rms",
|
|
||||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
|
||||||
)
|
|
||||||
|
|
||||||
handle_deprecated(
|
|
||||||
self,
|
|
||||||
"enable_noop",
|
|
||||||
"eliminate_noops",
|
|
||||||
"v0.13.0 or v1.0.0, whichever is sooner",
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.eliminate_noops:
|
if not self.eliminate_noops:
|
||||||
if self.fuse_norm_quant or self.fuse_act_quant:
|
if self.fuse_norm_quant or self.fuse_act_quant:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
|
|||||||
@ -64,6 +64,11 @@ class KVTransferConfig:
|
|||||||
enable_permute_local_kv: bool = False
|
enable_permute_local_kv: bool = False
|
||||||
"""Experiment feature flag to enable HND to NHD KV Transfer"""
|
"""Experiment feature flag to enable HND to NHD KV Transfer"""
|
||||||
|
|
||||||
|
kv_load_failure_policy: Literal["recompute", "fail"] = "recompute"
|
||||||
|
"""Policy for handling KV cache load failures.
|
||||||
|
'recompute': reschedule the request to recompute failed blocks (default)
|
||||||
|
'fail': immediately fail the request with an error finish reason"""
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
def compute_hash(self) -> str:
|
||||||
"""
|
"""
|
||||||
WARNING: Whenever a new field is added to this config,
|
WARNING: Whenever a new field is added to this config,
|
||||||
|
|||||||
@ -73,17 +73,6 @@ logger = init_logger(__name__)
|
|||||||
RunnerOption = Literal["auto", RunnerType]
|
RunnerOption = Literal["auto", RunnerType]
|
||||||
ConvertType = Literal["none", "embed", "classify", "reward"]
|
ConvertType = Literal["none", "embed", "classify", "reward"]
|
||||||
ConvertOption = Literal["auto", ConvertType]
|
ConvertOption = Literal["auto", ConvertType]
|
||||||
TaskOption = Literal[
|
|
||||||
"auto",
|
|
||||||
"generate",
|
|
||||||
"embedding",
|
|
||||||
"embed",
|
|
||||||
"classify",
|
|
||||||
"score",
|
|
||||||
"reward",
|
|
||||||
"transcription",
|
|
||||||
"draft",
|
|
||||||
]
|
|
||||||
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
|
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
|
||||||
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
|
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
|
||||||
LogprobsMode = Literal[
|
LogprobsMode = Literal[
|
||||||
@ -93,12 +82,6 @@ HfOverrides = dict[str, Any] | Callable[[PretrainedConfig], PretrainedConfig]
|
|||||||
ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"]
|
ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"]
|
||||||
LayerBlockType = Literal["attention", "linear_attention", "mamba"]
|
LayerBlockType = Literal["attention", "linear_attention", "mamba"]
|
||||||
|
|
||||||
_RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = {
|
|
||||||
"generate": ["generate", "transcription"],
|
|
||||||
"pooling": ["embedding", "embed", "classify", "score", "reward"],
|
|
||||||
"draft": ["draft"],
|
|
||||||
}
|
|
||||||
|
|
||||||
_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
|
_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
|
||||||
"generate": [],
|
"generate": [],
|
||||||
"pooling": ["embed", "classify", "reward"],
|
"pooling": ["embed", "classify", "reward"],
|
||||||
@ -126,12 +109,6 @@ class ModelConfig:
|
|||||||
"""Convert the model using adapters defined in
|
"""Convert the model using adapters defined in
|
||||||
[vllm.model_executor.models.adapters][]. The most common use case is to
|
[vllm.model_executor.models.adapters][]. The most common use case is to
|
||||||
adapt a text generation model to be used for pooling tasks."""
|
adapt a text generation model to be used for pooling tasks."""
|
||||||
task: TaskOption | None = None
|
|
||||||
"""[DEPRECATED] The task to use the model for. If the model supports more
|
|
||||||
than one model runner, this is used to select which model runner to run.
|
|
||||||
|
|
||||||
Note that the model may support other tasks using the same model runner.
|
|
||||||
"""
|
|
||||||
tokenizer: SkipValidation[str] = None # type: ignore
|
tokenizer: SkipValidation[str] = None # type: ignore
|
||||||
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
|
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
|
||||||
name or path will be used."""
|
name or path will be used."""
|
||||||
@ -335,7 +312,6 @@ class ModelConfig:
|
|||||||
ignored_factors = {
|
ignored_factors = {
|
||||||
"runner",
|
"runner",
|
||||||
"convert",
|
"convert",
|
||||||
"task",
|
|
||||||
"tokenizer",
|
"tokenizer",
|
||||||
"tokenizer_mode",
|
"tokenizer_mode",
|
||||||
"seed",
|
"seed",
|
||||||
@ -510,97 +486,6 @@ class ModelConfig:
|
|||||||
is_generative_model = registry.is_text_generation_model(architectures, self)
|
is_generative_model = registry.is_text_generation_model(architectures, self)
|
||||||
is_pooling_model = registry.is_pooling_model(architectures, self)
|
is_pooling_model = registry.is_pooling_model(architectures, self)
|
||||||
|
|
||||||
def _task_to_convert(task: TaskOption) -> ConvertType:
|
|
||||||
if task == "embedding" or task == "embed":
|
|
||||||
return "embed"
|
|
||||||
if task == "classify":
|
|
||||||
return "classify"
|
|
||||||
if task == "reward":
|
|
||||||
logger.warning(
|
|
||||||
"Pooling models now default support all pooling; "
|
|
||||||
"you can use it without any settings."
|
|
||||||
)
|
|
||||||
return "embed"
|
|
||||||
if task == "score":
|
|
||||||
new_task = self._get_default_pooling_task(architectures)
|
|
||||||
return "classify" if new_task == "classify" else "embed"
|
|
||||||
|
|
||||||
return "none"
|
|
||||||
|
|
||||||
if self.task is not None:
|
|
||||||
runner: RunnerOption = "auto"
|
|
||||||
convert: ConvertOption = "auto"
|
|
||||||
msg_prefix = (
|
|
||||||
"The 'task' option has been deprecated and will be "
|
|
||||||
"removed in v0.13.0 or v1.0, whichever comes first."
|
|
||||||
)
|
|
||||||
msg_hint = "Please remove this option."
|
|
||||||
|
|
||||||
is_generative_task = self.task in _RUNNER_TASKS["generate"]
|
|
||||||
is_pooling_task = self.task in _RUNNER_TASKS["pooling"]
|
|
||||||
|
|
||||||
if is_generative_model and is_pooling_model:
|
|
||||||
if is_generative_task:
|
|
||||||
runner = "generate"
|
|
||||||
convert = "auto"
|
|
||||||
msg_hint = (
|
|
||||||
"Please replace this option with `--runner "
|
|
||||||
"generate` to continue using this model "
|
|
||||||
"as a generative model."
|
|
||||||
)
|
|
||||||
elif is_pooling_task:
|
|
||||||
runner = "pooling"
|
|
||||||
convert = "auto"
|
|
||||||
msg_hint = (
|
|
||||||
"Please replace this option with `--runner "
|
|
||||||
"pooling` to continue using this model "
|
|
||||||
"as a pooling model."
|
|
||||||
)
|
|
||||||
else: # task == "auto"
|
|
||||||
pass
|
|
||||||
elif is_generative_model or is_pooling_model:
|
|
||||||
if is_generative_task:
|
|
||||||
runner = "generate"
|
|
||||||
convert = "auto"
|
|
||||||
msg_hint = "Please remove this option"
|
|
||||||
elif is_pooling_task:
|
|
||||||
runner = "pooling"
|
|
||||||
convert = _task_to_convert(self.task)
|
|
||||||
msg_hint = (
|
|
||||||
"Please replace this option with `--convert "
|
|
||||||
f"{convert}` to continue using this model "
|
|
||||||
"as a pooling model."
|
|
||||||
)
|
|
||||||
else: # task == "auto"
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# Neither generative nor pooling model - try to convert if possible
|
|
||||||
if is_pooling_task:
|
|
||||||
runner = "pooling"
|
|
||||||
convert = _task_to_convert(self.task)
|
|
||||||
msg_hint = (
|
|
||||||
"Please replace this option with `--runner pooling "
|
|
||||||
f"--convert {convert}` to continue using this model "
|
|
||||||
"as a pooling model."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
debug_info = {
|
|
||||||
"architectures": architectures,
|
|
||||||
"is_generative_model": is_generative_model,
|
|
||||||
"is_pooling_model": is_pooling_model,
|
|
||||||
}
|
|
||||||
raise AssertionError(
|
|
||||||
"The model should be a generative or "
|
|
||||||
"pooling model when task is set to "
|
|
||||||
f"{self.task!r}. Found: {debug_info}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.runner = runner
|
|
||||||
self.convert = convert
|
|
||||||
|
|
||||||
msg = f"{msg_prefix} {msg_hint}"
|
|
||||||
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
|
||||||
|
|
||||||
self.runner_type = self._get_runner_type(architectures, self.runner)
|
self.runner_type = self._get_runner_type(architectures, self.runner)
|
||||||
self.convert_type = self._get_convert_type(
|
self.convert_type = self._get_convert_type(
|
||||||
architectures, self.runner_type, self.convert
|
architectures, self.runner_type, self.convert
|
||||||
@ -903,6 +788,13 @@ class ModelConfig:
|
|||||||
runner_type: RunnerType,
|
runner_type: RunnerType,
|
||||||
convert: ConvertOption,
|
convert: ConvertOption,
|
||||||
) -> ConvertType:
|
) -> ConvertType:
|
||||||
|
if convert == "reward":
|
||||||
|
logger.warning(
|
||||||
|
"`--convert reward` is deprecated and will be removed in v0.15. "
|
||||||
|
"Please use `--convert embed` instead."
|
||||||
|
)
|
||||||
|
return "embed"
|
||||||
|
|
||||||
if convert != "auto":
|
if convert != "auto":
|
||||||
return convert
|
return convert
|
||||||
|
|
||||||
@ -918,22 +810,6 @@ class ModelConfig:
|
|||||||
|
|
||||||
return convert_type
|
return convert_type
|
||||||
|
|
||||||
def _get_default_pooling_task(
|
|
||||||
self,
|
|
||||||
architectures: list[str],
|
|
||||||
) -> Literal["embed", "classify", "reward"]:
|
|
||||||
if self.registry.is_cross_encoder_model(architectures, self):
|
|
||||||
return "classify"
|
|
||||||
|
|
||||||
for arch in architectures:
|
|
||||||
match = try_match_architecture_defaults(arch, runner_type="pooling")
|
|
||||||
if match:
|
|
||||||
_, (_, convert_type) = match
|
|
||||||
assert convert_type != "none"
|
|
||||||
return convert_type
|
|
||||||
|
|
||||||
return "embed"
|
|
||||||
|
|
||||||
def _parse_quant_hf_config(self, hf_config: PretrainedConfig):
|
def _parse_quant_hf_config(self, hf_config: PretrainedConfig):
|
||||||
quant_cfg = getattr(hf_config, "quantization_config", None)
|
quant_cfg = getattr(hf_config, "quantization_config", None)
|
||||||
if quant_cfg is None:
|
if quant_cfg is None:
|
||||||
|
|||||||
@ -321,11 +321,6 @@ class ParallelConfig:
|
|||||||
"num_redundant_experts."
|
"num_redundant_experts."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.prefill_context_parallel_size > 1:
|
|
||||||
raise ValueError(
|
|
||||||
"Prefill context parallelism is not fully supported. "
|
|
||||||
"Please set prefill_context_parallel_size to 1."
|
|
||||||
)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@ -111,13 +111,15 @@ class PoolerConfig:
|
|||||||
def get_use_activation(o: object):
|
def get_use_activation(o: object):
|
||||||
if softmax := getattr(o, "softmax", None) is not None:
|
if softmax := getattr(o, "softmax", None) is not None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"softmax will be deprecated, please use use_activation instead."
|
"softmax will be deprecated and will be removed in v0.15. "
|
||||||
|
"Please use use_activation instead."
|
||||||
)
|
)
|
||||||
return softmax
|
return softmax
|
||||||
|
|
||||||
if activation := getattr(o, "activation", None) is not None:
|
if activation := getattr(o, "activation", None) is not None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"activation will be deprecated, please use use_activation instead."
|
"activation will be deprecated and will be removed in v0.15. "
|
||||||
|
"Please use use_activation instead."
|
||||||
)
|
)
|
||||||
return activation
|
return activation
|
||||||
|
|
||||||
|
|||||||
@ -666,8 +666,9 @@ class VllmConfig:
|
|||||||
|
|
||||||
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
|
default_config = OPTIMIZATION_LEVEL_TO_CONFIG[self.optimization_level]
|
||||||
self._apply_optimization_level_defaults(default_config)
|
self._apply_optimization_level_defaults(default_config)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
|
||||||
and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
|
and self.compilation_config.mode != CompilationMode.VLLM_COMPILE
|
||||||
):
|
):
|
||||||
logger.info(
|
logger.info(
|
||||||
@ -692,22 +693,29 @@ class VllmConfig:
|
|||||||
|
|
||||||
if current_platform.support_static_graph_mode():
|
if current_platform.support_static_graph_mode():
|
||||||
# if cudagraph_mode has full cudagraphs, we need to check support
|
# if cudagraph_mode has full cudagraphs, we need to check support
|
||||||
if (
|
if model_config := self.model_config:
|
||||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
if (
|
||||||
and self.model_config is not None
|
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||||
):
|
and model_config.pooler_config is not None
|
||||||
if self.model_config.pooler_config is not None:
|
):
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"Pooling models do not support full cudagraphs. "
|
"Pooling models do not support full cudagraphs. "
|
||||||
"Overriding cudagraph_mode to PIECEWISE."
|
"Overriding cudagraph_mode to PIECEWISE."
|
||||||
)
|
)
|
||||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||||
elif self.model_config.is_encoder_decoder:
|
elif (
|
||||||
logger.warning_once(
|
model_config.is_encoder_decoder
|
||||||
"Encoder-decoder models do not support full cudagraphs. "
|
and self.compilation_config.cudagraph_mode
|
||||||
"Overriding cudagraph_mode to PIECEWISE."
|
not in (CUDAGraphMode.NONE, CUDAGraphMode.FULL_DECODE_ONLY)
|
||||||
|
):
|
||||||
|
logger.info_once(
|
||||||
|
"Encoder-decoder models do not support %s. "
|
||||||
|
"Overriding cudagraph_mode to FULL_DECODE_ONLY.",
|
||||||
|
self.compilation_config.cudagraph_mode.name,
|
||||||
|
)
|
||||||
|
self.compilation_config.cudagraph_mode = (
|
||||||
|
CUDAGraphMode.FULL_DECODE_ONLY
|
||||||
)
|
)
|
||||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
|
||||||
|
|
||||||
# disable cudagraph when enforce eager execution
|
# disable cudagraph when enforce eager execution
|
||||||
if self.model_config is not None and self.model_config.enforce_eager:
|
if self.model_config is not None and self.model_config.enforce_eager:
|
||||||
@ -812,11 +820,6 @@ class VllmConfig:
|
|||||||
f"({self.parallel_config.cp_kv_cache_interleave_size})."
|
f"({self.parallel_config.cp_kv_cache_interleave_size})."
|
||||||
)
|
)
|
||||||
|
|
||||||
assert (
|
|
||||||
self.parallel_config.cp_kv_cache_interleave_size == 1
|
|
||||||
or self.speculative_config is None
|
|
||||||
), "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
|
|
||||||
|
|
||||||
# Do this after all the updates to compilation_config.mode
|
# Do this after all the updates to compilation_config.mode
|
||||||
self.compilation_config.set_splitting_ops_for_v1(
|
self.compilation_config.set_splitting_ops_for_v1(
|
||||||
all2all_backend=self.parallel_config.all2all_backend,
|
all2all_backend=self.parallel_config.all2all_backend,
|
||||||
@ -1006,7 +1009,7 @@ class VllmConfig:
|
|||||||
max_graph_size = min(max_num_seqs * 2, 512)
|
max_graph_size = min(max_num_seqs * 2, 512)
|
||||||
# 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
|
# 1, 2, 4, then multiples of 8 up to 256 and then multiples of 16
|
||||||
# up to max_graph_size
|
# up to max_graph_size
|
||||||
cuda_graph_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
|
cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list(
|
||||||
range(256, max_graph_size + 1, 16))
|
range(256, max_graph_size + 1, 16))
|
||||||
|
|
||||||
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
|
In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
|
||||||
@ -1047,8 +1050,14 @@ class VllmConfig:
|
|||||||
self.compilation_config.max_cudagraph_capture_size
|
self.compilation_config.max_cudagraph_capture_size
|
||||||
)
|
)
|
||||||
if max_cudagraph_capture_size is None:
|
if max_cudagraph_capture_size is None:
|
||||||
|
decode_query_len = 1
|
||||||
|
if (
|
||||||
|
self.speculative_config
|
||||||
|
and self.speculative_config.num_speculative_tokens
|
||||||
|
):
|
||||||
|
decode_query_len += self.speculative_config.num_speculative_tokens
|
||||||
max_cudagraph_capture_size = min(
|
max_cudagraph_capture_size = min(
|
||||||
self.scheduler_config.max_num_seqs * 2, 512
|
self.scheduler_config.max_num_seqs * decode_query_len * 2, 512
|
||||||
)
|
)
|
||||||
max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||||
max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)
|
max_cudagraph_capture_size = min(max_num_tokens, max_cudagraph_capture_size)
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import functools
|
import functools
|
||||||
import pickle
|
import pickle
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
@ -43,6 +44,33 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
|||||||
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")
|
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")
|
||||||
|
|
||||||
|
|
||||||
|
# Memory fence for cross-process shared memory visibility.
|
||||||
|
# Required for correct producer-consumer synchronization when using
|
||||||
|
# shared memory without locks.
|
||||||
|
_memory_fence_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def memory_fence():
|
||||||
|
"""
|
||||||
|
Full memory barrier for shared memory synchronization.
|
||||||
|
|
||||||
|
Ensures all prior memory writes are visible to other processes before
|
||||||
|
any subsequent reads. This is critical for lock-free producer-consumer
|
||||||
|
patterns using shared memory.
|
||||||
|
|
||||||
|
Implementation acquires and immediately releases a lock. Python's
|
||||||
|
threading.Lock provides sequentially consistent memory barrier semantics
|
||||||
|
across all major platforms (POSIX, Windows). This is a lightweight
|
||||||
|
operation (~20ns) that guarantees:
|
||||||
|
- All stores before the barrier are visible to other threads/processes
|
||||||
|
- All loads after the barrier see the latest values
|
||||||
|
"""
|
||||||
|
# Lock acquire/release provides full memory barrier semantics.
|
||||||
|
# Using context manager ensures lock release even on exceptions.
|
||||||
|
with _memory_fence_lock:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def to_bytes_big(value: int, size: int) -> bytes:
|
def to_bytes_big(value: int, size: int) -> bytes:
|
||||||
return value.to_bytes(size, byteorder="big")
|
return value.to_bytes(size, byteorder="big")
|
||||||
|
|
||||||
@ -414,6 +442,10 @@ class MessageQueue:
|
|||||||
n_warning = 1
|
n_warning = 1
|
||||||
while True:
|
while True:
|
||||||
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
||||||
|
# Memory fence ensures we see the latest read flags from readers.
|
||||||
|
# Without this, we may read stale flags from our CPU cache and
|
||||||
|
# spin indefinitely even though readers have completed.
|
||||||
|
memory_fence()
|
||||||
read_count = sum(metadata_buffer[1:])
|
read_count = sum(metadata_buffer[1:])
|
||||||
written_flag = metadata_buffer[0]
|
written_flag = metadata_buffer[0]
|
||||||
if written_flag and read_count != self.buffer.n_reader:
|
if written_flag and read_count != self.buffer.n_reader:
|
||||||
@ -458,6 +490,10 @@ class MessageQueue:
|
|||||||
metadata_buffer[i] = 0
|
metadata_buffer[i] = 0
|
||||||
# mark the block as written
|
# mark the block as written
|
||||||
metadata_buffer[0] = 1
|
metadata_buffer[0] = 1
|
||||||
|
# Memory fence ensures the write is visible to readers on other cores
|
||||||
|
# before we proceed. Without this, readers may spin indefinitely
|
||||||
|
# waiting for a write that's stuck in our CPU's store buffer.
|
||||||
|
memory_fence()
|
||||||
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
|
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -473,6 +509,10 @@ class MessageQueue:
|
|||||||
n_warning = 1
|
n_warning = 1
|
||||||
while True:
|
while True:
|
||||||
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
|
||||||
|
# Memory fence ensures we see the latest writes from the writer.
|
||||||
|
# Without this, we may read stale flags from our CPU cache
|
||||||
|
# and spin indefinitely even though writer has updated them.
|
||||||
|
memory_fence()
|
||||||
read_flag = metadata_buffer[self.local_reader_rank + 1]
|
read_flag = metadata_buffer[self.local_reader_rank + 1]
|
||||||
written_flag = metadata_buffer[0]
|
written_flag = metadata_buffer[0]
|
||||||
if not written_flag or read_flag:
|
if not written_flag or read_flag:
|
||||||
@ -513,6 +553,10 @@ class MessageQueue:
|
|||||||
# caller has read from the buffer
|
# caller has read from the buffer
|
||||||
# set the read flag
|
# set the read flag
|
||||||
metadata_buffer[self.local_reader_rank + 1] = 1
|
metadata_buffer[self.local_reader_rank + 1] = 1
|
||||||
|
# Memory fence ensures the read flag is visible to the writer.
|
||||||
|
# Without this, writer may not see our read completion and
|
||||||
|
# could wait indefinitely for all readers to finish.
|
||||||
|
memory_fence()
|
||||||
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
|
self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
|
||||||
|
|
||||||
self._read_spin_timer.record_activity()
|
self._read_spin_timer.record_activity()
|
||||||
|
|||||||
@ -491,6 +491,9 @@ async def transfer_layer(
|
|||||||
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
|
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
|
||||||
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
|
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
|
||||||
assert num_physical_experts == ep_size * num_local_physical_experts
|
assert num_physical_experts == ep_size * num_local_physical_experts
|
||||||
|
# A buffer to hold the expert weights in one layer during the exchange.
|
||||||
|
# NOTE: Currently we assume the same weights across different layers
|
||||||
|
# have the same shape.
|
||||||
|
|
||||||
old_global_expert_indices_np = old_global_expert_indices.cpu().numpy()
|
old_global_expert_indices_np = old_global_expert_indices.cpu().numpy()
|
||||||
new_global_expert_indices_np = new_global_expert_indices.cpu().numpy()
|
new_global_expert_indices_np = new_global_expert_indices.cpu().numpy()
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
|
|||||||
LMCacheAsyncLookupServer,
|
LMCacheAsyncLookupServer,
|
||||||
)
|
)
|
||||||
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
|
from lmcache.v1.offload_server.zmq_server import ZMQOffloadServer
|
||||||
from lmcache.v1.plugin.plugin_launcher import PluginLauncher
|
from lmcache.v1.plugin.runtime_plugin_launcher import RuntimePluginLauncher
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
@ -683,7 +683,7 @@ class LMCacheConnectorV1Impl:
|
|||||||
self.api_server = InternalAPIServer(self)
|
self.api_server = InternalAPIServer(self)
|
||||||
self.api_server.start()
|
self.api_server.start()
|
||||||
# Launch plugins
|
# Launch plugins
|
||||||
self.plugin_launcher = PluginLauncher(
|
self.plugin_launcher = RuntimePluginLauncher(
|
||||||
self.config,
|
self.config,
|
||||||
role,
|
role,
|
||||||
self.worker_count,
|
self.worker_count,
|
||||||
|
|||||||
@ -1586,6 +1586,8 @@ def destroy_distributed_environment():
|
|||||||
|
|
||||||
|
|
||||||
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
|
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
|
||||||
|
# Reset environment variable cache
|
||||||
|
envs.disable_envs_cache()
|
||||||
# Ensure all objects are not frozen before cleanup
|
# Ensure all objects are not frozen before cleanup
|
||||||
gc.unfreeze()
|
gc.unfreeze()
|
||||||
|
|
||||||
|
|||||||
@ -71,7 +71,6 @@ from vllm.config.model import (
|
|||||||
LogprobsMode,
|
LogprobsMode,
|
||||||
ModelDType,
|
ModelDType,
|
||||||
RunnerOption,
|
RunnerOption,
|
||||||
TaskOption,
|
|
||||||
TokenizerMode,
|
TokenizerMode,
|
||||||
)
|
)
|
||||||
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
|
from vllm.config.multimodal import MMCacheType, MMEncoderTPMode
|
||||||
@ -360,7 +359,6 @@ class EngineArgs:
|
|||||||
hf_config_path: str | None = ModelConfig.hf_config_path
|
hf_config_path: str | None = ModelConfig.hf_config_path
|
||||||
runner: RunnerOption = ModelConfig.runner
|
runner: RunnerOption = ModelConfig.runner
|
||||||
convert: ConvertOption = ModelConfig.convert
|
convert: ConvertOption = ModelConfig.convert
|
||||||
task: TaskOption | None = ModelConfig.task
|
|
||||||
skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
|
skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
|
||||||
enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
|
enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
|
||||||
tokenizer_mode: TokenizerMode | str = ModelConfig.tokenizer_mode
|
tokenizer_mode: TokenizerMode | str = ModelConfig.tokenizer_mode
|
||||||
@ -373,9 +371,8 @@ class EngineArgs:
|
|||||||
config_format: str = ModelConfig.config_format
|
config_format: str = ModelConfig.config_format
|
||||||
dtype: ModelDType = ModelConfig.dtype
|
dtype: ModelDType = ModelConfig.dtype
|
||||||
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
|
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
|
||||||
seed: int | None = 0
|
seed: int = ModelConfig.seed
|
||||||
max_model_len: int | None = ModelConfig.max_model_len
|
max_model_len: int | None = ModelConfig.max_model_len
|
||||||
cuda_graph_sizes: list[int] | None = CompilationConfig.cudagraph_capture_sizes
|
|
||||||
cudagraph_capture_sizes: list[int] | None = (
|
cudagraph_capture_sizes: list[int] | None = (
|
||||||
CompilationConfig.cudagraph_capture_sizes
|
CompilationConfig.cudagraph_capture_sizes
|
||||||
)
|
)
|
||||||
@ -463,7 +460,6 @@ class EngineArgs:
|
|||||||
MultiModalConfig, "media_io_kwargs"
|
MultiModalConfig, "media_io_kwargs"
|
||||||
)
|
)
|
||||||
mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs
|
mm_processor_kwargs: dict[str, Any] | None = MultiModalConfig.mm_processor_kwargs
|
||||||
disable_mm_preprocessor_cache: bool = False # DEPRECATED
|
|
||||||
mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
|
mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
|
||||||
mm_processor_cache_type: MMCacheType | None = (
|
mm_processor_cache_type: MMCacheType | None = (
|
||||||
MultiModalConfig.mm_processor_cache_type
|
MultiModalConfig.mm_processor_cache_type
|
||||||
@ -559,9 +555,6 @@ class EngineArgs:
|
|||||||
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
|
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
|
||||||
pt_load_map_location: str = LoadConfig.pt_load_map_location
|
pt_load_map_location: str = LoadConfig.pt_load_map_location
|
||||||
|
|
||||||
# DEPRECATED
|
|
||||||
enable_multimodal_encoder_data_parallel: bool = False
|
|
||||||
|
|
||||||
logits_processors: list[str | type[LogitsProcessor]] | None = (
|
logits_processors: list[str | type[LogitsProcessor]] | None = (
|
||||||
ModelConfig.logits_processors
|
ModelConfig.logits_processors
|
||||||
)
|
)
|
||||||
@ -629,7 +622,6 @@ class EngineArgs:
|
|||||||
model_group.add_argument("--model", **model_kwargs["model"])
|
model_group.add_argument("--model", **model_kwargs["model"])
|
||||||
model_group.add_argument("--runner", **model_kwargs["runner"])
|
model_group.add_argument("--runner", **model_kwargs["runner"])
|
||||||
model_group.add_argument("--convert", **model_kwargs["convert"])
|
model_group.add_argument("--convert", **model_kwargs["convert"])
|
||||||
model_group.add_argument("--task", **model_kwargs["task"], deprecated=True)
|
|
||||||
model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
|
model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
|
||||||
model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
|
model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"])
|
||||||
model_group.add_argument(
|
model_group.add_argument(
|
||||||
@ -883,11 +875,6 @@ class EngineArgs:
|
|||||||
parallel_group.add_argument(
|
parallel_group.add_argument(
|
||||||
"--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
|
"--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
|
||||||
)
|
)
|
||||||
parallel_group.add_argument(
|
|
||||||
"--enable-multimodal-encoder-data-parallel",
|
|
||||||
action="store_true",
|
|
||||||
deprecated=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# KV cache arguments
|
# KV cache arguments
|
||||||
cache_kwargs = get_kwargs(CacheConfig)
|
cache_kwargs = get_kwargs(CacheConfig)
|
||||||
@ -961,9 +948,6 @@ class EngineArgs:
|
|||||||
multimodal_group.add_argument(
|
multimodal_group.add_argument(
|
||||||
"--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
|
"--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"]
|
||||||
)
|
)
|
||||||
multimodal_group.add_argument(
|
|
||||||
"--disable-mm-preprocessor-cache", action="store_true", deprecated=True
|
|
||||||
)
|
|
||||||
multimodal_group.add_argument(
|
multimodal_group.add_argument(
|
||||||
"--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
|
"--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"]
|
||||||
)
|
)
|
||||||
@ -1121,15 +1105,6 @@ class EngineArgs:
|
|||||||
compilation_group.add_argument(
|
compilation_group.add_argument(
|
||||||
"--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"]
|
"--cudagraph-capture-sizes", **compilation_kwargs["cudagraph_capture_sizes"]
|
||||||
)
|
)
|
||||||
compilation_kwargs["cudagraph_capture_sizes"]["help"] = (
|
|
||||||
"--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or v1.0.0,"
|
|
||||||
" whichever is soonest. Please use --cudagraph-capture-sizes instead."
|
|
||||||
)
|
|
||||||
compilation_group.add_argument(
|
|
||||||
"--cuda-graph-sizes",
|
|
||||||
**compilation_kwargs["cudagraph_capture_sizes"],
|
|
||||||
deprecated=True,
|
|
||||||
)
|
|
||||||
compilation_group.add_argument(
|
compilation_group.add_argument(
|
||||||
"--max-cudagraph-capture-size",
|
"--max-cudagraph-capture-size",
|
||||||
**compilation_kwargs["max_cudagraph_capture_size"],
|
**compilation_kwargs["max_cudagraph_capture_size"],
|
||||||
@ -1202,62 +1177,20 @@ class EngineArgs:
|
|||||||
if is_gguf(self.model):
|
if is_gguf(self.model):
|
||||||
self.quantization = self.load_format = "gguf"
|
self.quantization = self.load_format = "gguf"
|
||||||
|
|
||||||
# NOTE(woosuk): In V1, we use separate processes for workers (unless
|
if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
|
||||||
# VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here
|
logger.warning(
|
||||||
# doesn't affect the user process.
|
"The global random seed is set to %d. Since "
|
||||||
if self.seed is None:
|
"VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
|
||||||
logger.warning_once(
|
"affect the random state of the Python process that "
|
||||||
"`seed=None` is equivalent to `seed=0` in V1 Engine. "
|
"launched vLLM.",
|
||||||
"You will no longer be allowed to pass `None` in v0.13.",
|
self.seed,
|
||||||
scope="local",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.seed = 0
|
|
||||||
if not envs.VLLM_ENABLE_V1_MULTIPROCESSING:
|
|
||||||
logger.warning(
|
|
||||||
"The global random seed is set to %d. Since "
|
|
||||||
"VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
|
|
||||||
"affect the random state of the Python process that "
|
|
||||||
"launched vLLM.",
|
|
||||||
self.seed,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.disable_mm_preprocessor_cache:
|
|
||||||
logger.warning_once(
|
|
||||||
"`--disable-mm-preprocessor-cache` is deprecated "
|
|
||||||
"and will be removed in v0.13. "
|
|
||||||
"Please use `--mm-processor-cache-gb 0` instead.",
|
|
||||||
scope="local",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.mm_processor_cache_gb = 0
|
|
||||||
elif envs.VLLM_MM_INPUT_CACHE_GIB != 4:
|
|
||||||
logger.warning_once(
|
|
||||||
"VLLM_MM_INPUT_CACHE_GIB` is deprecated "
|
|
||||||
"and will be removed in v0.13. "
|
|
||||||
"Please use `--mm-processor-cache-gb %d` instead.",
|
|
||||||
envs.VLLM_MM_INPUT_CACHE_GIB,
|
|
||||||
scope="local",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB
|
|
||||||
|
|
||||||
if self.enable_multimodal_encoder_data_parallel:
|
|
||||||
logger.warning_once(
|
|
||||||
"--enable-multimodal-encoder-data-parallel` is deprecated "
|
|
||||||
"and will be removed in v0.13. "
|
|
||||||
"Please use `--mm-encoder-tp-mode data` instead.",
|
|
||||||
scope="local",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.mm_encoder_tp_mode = "data"
|
|
||||||
|
|
||||||
return ModelConfig(
|
return ModelConfig(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
hf_config_path=self.hf_config_path,
|
hf_config_path=self.hf_config_path,
|
||||||
runner=self.runner,
|
runner=self.runner,
|
||||||
convert=self.convert,
|
convert=self.convert,
|
||||||
task=self.task,
|
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
tokenizer_mode=self.tokenizer_mode,
|
tokenizer_mode=self.tokenizer_mode,
|
||||||
trust_remote_code=self.trust_remote_code,
|
trust_remote_code=self.trust_remote_code,
|
||||||
@ -1741,18 +1674,6 @@ class EngineArgs:
|
|||||||
|
|
||||||
# Compilation config overrides
|
# Compilation config overrides
|
||||||
compilation_config = copy.deepcopy(self.compilation_config)
|
compilation_config = copy.deepcopy(self.compilation_config)
|
||||||
if self.cuda_graph_sizes is not None:
|
|
||||||
logger.warning(
|
|
||||||
"--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or "
|
|
||||||
"v1.0.0, whichever is soonest. Please use --cudagraph-capture-sizes "
|
|
||||||
"instead."
|
|
||||||
)
|
|
||||||
if compilation_config.cudagraph_capture_sizes is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"cuda_graph_sizes and compilation_config."
|
|
||||||
"cudagraph_capture_sizes are mutually exclusive"
|
|
||||||
)
|
|
||||||
compilation_config.cudagraph_capture_sizes = self.cuda_graph_sizes
|
|
||||||
if self.cudagraph_capture_sizes is not None:
|
if self.cudagraph_capture_sizes is not None:
|
||||||
if compilation_config.cudagraph_capture_sizes is not None:
|
if compilation_config.cudagraph_capture_sizes is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -1862,6 +1783,7 @@ class EngineArgs:
|
|||||||
except Exception:
|
except Exception:
|
||||||
# This is only used to set default_max_num_batched_tokens
|
# This is only used to set default_max_num_batched_tokens
|
||||||
device_memory = 0
|
device_memory = 0
|
||||||
|
device_name = ""
|
||||||
|
|
||||||
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
|
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
|
||||||
# throughput, see PR #17885 for more details.
|
# throughput, see PR #17885 for more details.
|
||||||
@ -1926,16 +1848,6 @@ class EngineArgs:
|
|||||||
default_chunked_prefill = model_config.is_chunked_prefill_supported
|
default_chunked_prefill = model_config.is_chunked_prefill_supported
|
||||||
default_prefix_caching = model_config.is_prefix_caching_supported
|
default_prefix_caching = model_config.is_prefix_caching_supported
|
||||||
|
|
||||||
if self.prefill_context_parallel_size > 1:
|
|
||||||
default_chunked_prefill = False
|
|
||||||
default_prefix_caching = False
|
|
||||||
logger.warning_once(
|
|
||||||
"--prefill-context-parallel-size > 1 is not compatible with "
|
|
||||||
"chunked prefill and prefix caching now. Chunked prefill "
|
|
||||||
"and prefix caching have been disabled by default.",
|
|
||||||
scope="local",
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.enable_chunked_prefill is None:
|
if self.enable_chunked_prefill is None:
|
||||||
self.enable_chunked_prefill = default_chunked_prefill
|
self.enable_chunked_prefill = default_chunked_prefill
|
||||||
|
|
||||||
@ -2121,11 +2033,13 @@ def human_readable_int(value):
|
|||||||
"k": 10**3,
|
"k": 10**3,
|
||||||
"m": 10**6,
|
"m": 10**6,
|
||||||
"g": 10**9,
|
"g": 10**9,
|
||||||
|
"t": 10**12,
|
||||||
}
|
}
|
||||||
binary_multiplier = {
|
binary_multiplier = {
|
||||||
"K": 2**10,
|
"K": 2**10,
|
||||||
"M": 2**20,
|
"M": 2**20,
|
||||||
"G": 2**30,
|
"G": 2**30,
|
||||||
|
"T": 2**40,
|
||||||
}
|
}
|
||||||
|
|
||||||
number, suffix = match.groups()
|
number, suffix = match.groups()
|
||||||
|
|||||||
@ -8,3 +8,5 @@ Shared constants for vLLM entrypoints.
|
|||||||
# These constants help mitigate header abuse attacks
|
# These constants help mitigate header abuse attacks
|
||||||
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304 # 4 MB
|
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT = 4194304 # 4 MB
|
||||||
H11_MAX_HEADER_COUNT_DEFAULT = 256
|
H11_MAX_HEADER_COUNT_DEFAULT = 256
|
||||||
|
|
||||||
|
MCP_PREFIX = "mcp_"
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from vllm import envs
|
|||||||
from vllm.entrypoints.chat_utils import (
|
from vllm.entrypoints.chat_utils import (
|
||||||
ChatTemplateContentFormatOption,
|
ChatTemplateContentFormatOption,
|
||||||
)
|
)
|
||||||
|
from vllm.entrypoints.constants import MCP_PREFIX
|
||||||
from vllm.entrypoints.openai.parser.harmony_utils import (
|
from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||||
get_encoding,
|
get_encoding,
|
||||||
get_streamable_parser_for_assistant,
|
get_streamable_parser_for_assistant,
|
||||||
@ -303,7 +304,7 @@ class ParsableContext(ConversationContext):
|
|||||||
result_str = result.content[0].text
|
result_str = result.content[0].text
|
||||||
|
|
||||||
message = ResponseFunctionToolCallOutputItem(
|
message = ResponseFunctionToolCallOutputItem(
|
||||||
id=f"fco_{random_uuid()}",
|
id=f"mcpo_{random_uuid()}",
|
||||||
type="function_call_output",
|
type="function_call_output",
|
||||||
call_id=f"call_{random_uuid()}",
|
call_id=f"call_{random_uuid()}",
|
||||||
output=result_str,
|
output=result_str,
|
||||||
@ -385,6 +386,9 @@ class ParsableContext(ConversationContext):
|
|||||||
if not self.parser.response_messages:
|
if not self.parser.response_messages:
|
||||||
return []
|
return []
|
||||||
last_msg = self.parser.response_messages[-1]
|
last_msg = self.parser.response_messages[-1]
|
||||||
|
# change this to a mcp_ function call
|
||||||
|
last_msg.id = f"{MCP_PREFIX}{random_uuid()}"
|
||||||
|
self.parser.response_messages[-1] = last_msg
|
||||||
if last_msg.name == "code_interpreter":
|
if last_msg.name == "code_interpreter":
|
||||||
return await self.call_python_tool(self._tool_sessions["python"], last_msg)
|
return await self.call_python_tool(self._tool_sessions["python"], last_msg)
|
||||||
elif last_msg.name == "web_search_preview":
|
elif last_msg.name == "web_search_preview":
|
||||||
|
|||||||
@ -9,7 +9,7 @@ import cloudpickle
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from typing_extensions import TypeVar, deprecated
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
from vllm.beam_search import (
|
from vllm.beam_search import (
|
||||||
BeamSearchInstance,
|
BeamSearchInstance,
|
||||||
@ -73,7 +73,6 @@ from vllm.pooling_params import PoolingParams
|
|||||||
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
|
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
|
||||||
from vllm.tasks import PoolingTask
|
from vllm.tasks import PoolingTask
|
||||||
from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
from vllm.tokenizers import MistralTokenizer, TokenizerLike
|
||||||
from vllm.tokenizers.hf import get_cached_tokenizer
|
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils.collection_utils import as_iter, is_list_of
|
from vllm.utils.collection_utils import as_iter, is_list_of
|
||||||
from vllm.utils.counter import Counter
|
from vllm.utils.counter import Counter
|
||||||
@ -199,7 +198,7 @@ class LLM:
|
|||||||
quantization: QuantizationMethods | None = None,
|
quantization: QuantizationMethods | None = None,
|
||||||
revision: str | None = None,
|
revision: str | None = None,
|
||||||
tokenizer_revision: str | None = None,
|
tokenizer_revision: str | None = None,
|
||||||
seed: int | None = None,
|
seed: int = 0,
|
||||||
gpu_memory_utilization: float = 0.9,
|
gpu_memory_utilization: float = 0.9,
|
||||||
swap_space: float = 4,
|
swap_space: float = 4,
|
||||||
cpu_offload_gb: float = 0,
|
cpu_offload_gb: float = 0,
|
||||||
@ -367,16 +366,6 @@ class LLM:
|
|||||||
def get_tokenizer(self) -> TokenizerLike:
|
def get_tokenizer(self) -> TokenizerLike:
|
||||||
return self.llm_engine.get_tokenizer()
|
return self.llm_engine.get_tokenizer()
|
||||||
|
|
||||||
@deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.")
|
|
||||||
def set_tokenizer(self, tokenizer: TokenizerLike) -> None:
|
|
||||||
# While CachedTokenizer is dynamic, have no choice but
|
|
||||||
# compare class name. Misjudgment will arise from
|
|
||||||
# user-defined tokenizer started with 'Cached'
|
|
||||||
if tokenizer.__class__.__name__.startswith("Cached"):
|
|
||||||
self.llm_engine.tokenizer = tokenizer
|
|
||||||
else:
|
|
||||||
self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
|
|
||||||
|
|
||||||
def reset_mm_cache(self) -> None:
|
def reset_mm_cache(self) -> None:
|
||||||
self.input_processor.clear_mm_cache()
|
self.input_processor.clear_mm_cache()
|
||||||
self.llm_engine.reset_mm_cache()
|
self.llm_engine.reset_mm_cache()
|
||||||
|
|||||||
@ -176,7 +176,7 @@ class FrontendArgs:
|
|||||||
enable_force_include_usage: bool = False
|
enable_force_include_usage: bool = False
|
||||||
"""If set to True, including usage on every request."""
|
"""If set to True, including usage on every request."""
|
||||||
enable_tokenizer_info_endpoint: bool = False
|
enable_tokenizer_info_endpoint: bool = False
|
||||||
"""Enable the /get_tokenizer_info endpoint. May expose chat
|
"""Enable the `/tokenizer_info` endpoint. May expose chat
|
||||||
templates and other tokenizer configuration."""
|
templates and other tokenizer configuration."""
|
||||||
enable_log_outputs: bool = False
|
enable_log_outputs: bool = False
|
||||||
"""If True, log model outputs (generations).
|
"""If True, log model outputs (generations).
|
||||||
|
|||||||
@ -51,7 +51,11 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
ToolCall,
|
ToolCall,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
)
|
)
|
||||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
|
from vllm.entrypoints.openai.serving_engine import (
|
||||||
|
GenerationError,
|
||||||
|
OpenAIServing,
|
||||||
|
clamp_prompt_logprobs,
|
||||||
|
)
|
||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
from vllm.entrypoints.openai.tool_parsers import ToolParser
|
||||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
|
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
|
||||||
@ -380,6 +384,8 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
request_metadata,
|
request_metadata,
|
||||||
)
|
)
|
||||||
|
except GenerationError as e:
|
||||||
|
return self._convert_generation_error_to_response(e)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# TODO: Use a vllm-specific Validation Error
|
# TODO: Use a vllm-specific Validation Error
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
@ -1120,6 +1126,10 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
|
|
||||||
# if the model is finished generating
|
# if the model is finished generating
|
||||||
else:
|
else:
|
||||||
|
# check for error finish reason and abort streaming
|
||||||
|
# finish_reason='error' indicates a retryable error
|
||||||
|
self._raise_if_error(output.finish_reason, request_id)
|
||||||
|
|
||||||
# check to make sure we haven't "forgotten" to stream
|
# check to make sure we haven't "forgotten" to stream
|
||||||
# any tokens that were generated but previously
|
# any tokens that were generated but previously
|
||||||
# matched by partial json parsing
|
# matched by partial json parsing
|
||||||
@ -1287,6 +1297,8 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
delta=False,
|
delta=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
except GenerationError as e:
|
||||||
|
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# TODO: Use a vllm-specific Validation Error
|
# TODO: Use a vllm-specific Validation Error
|
||||||
logger.exception("Error in chat completion stream generator.")
|
logger.exception("Error in chat completion stream generator.")
|
||||||
@ -1327,6 +1339,9 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
|
|
||||||
role = self.get_chat_request_role(request)
|
role = self.get_chat_request_role(request)
|
||||||
for output in final_res.outputs:
|
for output in final_res.outputs:
|
||||||
|
# check for error finish reason and raise GenerationError
|
||||||
|
# finish_reason='error' indicates a retryable request-level internal error
|
||||||
|
self._raise_if_error(output.finish_reason, request_id)
|
||||||
token_ids = output.token_ids
|
token_ids = output.token_ids
|
||||||
out_logprobs = output.logprobs
|
out_logprobs = output.logprobs
|
||||||
tool_call_info = None
|
tool_call_info = None
|
||||||
|
|||||||
@ -24,7 +24,11 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
RequestResponseMetadata,
|
RequestResponseMetadata,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
)
|
)
|
||||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
|
from vllm.entrypoints.openai.serving_engine import (
|
||||||
|
GenerationError,
|
||||||
|
OpenAIServing,
|
||||||
|
clamp_prompt_logprobs,
|
||||||
|
)
|
||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
from vllm.entrypoints.renderer import RenderConfig
|
from vllm.entrypoints.renderer import RenderConfig
|
||||||
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
||||||
@ -300,6 +304,8 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
)
|
)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
return self.create_error_response("Client disconnected")
|
return self.create_error_response("Client disconnected")
|
||||||
|
except GenerationError as e:
|
||||||
|
return self._convert_generation_error_to_response(e)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# TODO: Use a vllm-specific Validation Error
|
# TODO: Use a vllm-specific Validation Error
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
@ -437,6 +443,8 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
finish_reason = output.finish_reason
|
finish_reason = output.finish_reason
|
||||||
stop_reason = output.stop_reason
|
stop_reason = output.stop_reason
|
||||||
|
|
||||||
|
self._raise_if_error(finish_reason, request_id)
|
||||||
|
|
||||||
chunk = CompletionStreamResponse(
|
chunk = CompletionStreamResponse(
|
||||||
id=request_id,
|
id=request_id,
|
||||||
created=created_time,
|
created=created_time,
|
||||||
@ -498,8 +506,11 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
# report to FastAPI middleware aggregate usage across all choices
|
# report to FastAPI middleware aggregate usage across all choices
|
||||||
request_metadata.final_usage_info = final_usage_info
|
request_metadata.final_usage_info = final_usage_info
|
||||||
|
|
||||||
|
except GenerationError as e:
|
||||||
|
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# TODO: Use a vllm-specific Validation Error
|
# TODO: Use a vllm-specific Validation Error
|
||||||
|
logger.exception("Error in completion stream generator.")
|
||||||
data = self.create_streaming_error_response(str(e))
|
data = self.create_streaming_error_response(str(e))
|
||||||
yield f"data: {data}\n\n"
|
yield f"data: {data}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
@ -530,6 +541,8 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
|
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
|
||||||
|
|
||||||
for output in final_res.outputs:
|
for output in final_res.outputs:
|
||||||
|
self._raise_if_error(output.finish_reason, request_id)
|
||||||
|
|
||||||
assert request.max_tokens is not None
|
assert request.max_tokens is not None
|
||||||
if request.echo:
|
if request.echo:
|
||||||
if request.return_token_ids:
|
if request.return_token_ids:
|
||||||
|
|||||||
@ -133,6 +133,15 @@ from vllm.utils.async_utils import (
|
|||||||
from vllm.utils.collection_utils import is_list_of
|
from vllm.utils.collection_utils import is_list_of
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
|
|
||||||
|
|
||||||
|
class GenerationError(Exception):
|
||||||
|
"""raised when finish_reason indicates internal server error (500)"""
|
||||||
|
|
||||||
|
def __init__(self, message: str = "Internal server error"):
|
||||||
|
super().__init__(message)
|
||||||
|
self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
CompletionLikeRequest: TypeAlias = (
|
CompletionLikeRequest: TypeAlias = (
|
||||||
@ -456,6 +465,29 @@ class OpenAIServing:
|
|||||||
# Iterate through all beam inference results
|
# Iterate through all beam inference results
|
||||||
for i, result in enumerate(output):
|
for i, result in enumerate(output):
|
||||||
current_beam = all_beams[i]
|
current_beam = all_beams[i]
|
||||||
|
|
||||||
|
# check for error finish reason and abort beam search
|
||||||
|
if result.outputs[0].finish_reason == "error":
|
||||||
|
# yield error output and terminate beam search
|
||||||
|
yield RequestOutput(
|
||||||
|
request_id=request_id,
|
||||||
|
prompt=prompt_text,
|
||||||
|
outputs=[
|
||||||
|
CompletionOutput(
|
||||||
|
index=0,
|
||||||
|
text="",
|
||||||
|
token_ids=[],
|
||||||
|
cumulative_logprob=None,
|
||||||
|
logprobs=None,
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
finished=True,
|
||||||
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
prompt_logprobs=None,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
if result.outputs[0].logprobs is not None:
|
if result.outputs[0].logprobs is not None:
|
||||||
logprobs = result.outputs[0].logprobs[0]
|
logprobs = result.outputs[0].logprobs[0]
|
||||||
all_beams_token_id.extend(list(logprobs.keys()))
|
all_beams_token_id.extend(list(logprobs.keys()))
|
||||||
@ -780,6 +812,35 @@ class OpenAIServing:
|
|||||||
)
|
)
|
||||||
return json_str
|
return json_str
|
||||||
|
|
||||||
|
def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
|
||||||
|
"""Raise GenerationError if finish_reason indicates an error."""
|
||||||
|
if finish_reason == "error":
|
||||||
|
logger.error(
|
||||||
|
"Request %s failed with an internal error during generation",
|
||||||
|
request_id,
|
||||||
|
)
|
||||||
|
raise GenerationError("Internal server error")
|
||||||
|
|
||||||
|
def _convert_generation_error_to_response(
|
||||||
|
self, e: GenerationError
|
||||||
|
) -> ErrorResponse:
|
||||||
|
"""Convert GenerationError to ErrorResponse."""
|
||||||
|
return self.create_error_response(
|
||||||
|
str(e),
|
||||||
|
err_type="InternalServerError",
|
||||||
|
status_code=e.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_generation_error_to_streaming_response(
|
||||||
|
self, e: GenerationError
|
||||||
|
) -> str:
|
||||||
|
"""Convert GenerationError to streaming error response."""
|
||||||
|
return self.create_streaming_error_response(
|
||||||
|
str(e),
|
||||||
|
err_type="InternalServerError",
|
||||||
|
status_code=e.status_code,
|
||||||
|
)
|
||||||
|
|
||||||
async def _check_model(
|
async def _check_model(
|
||||||
self,
|
self,
|
||||||
request: AnyRequest,
|
request: AnyRequest,
|
||||||
@ -1339,6 +1400,7 @@ class OpenAIServing:
|
|||||||
)
|
)
|
||||||
engine_prompt = engine_prompts[0]
|
engine_prompt = engine_prompts[0]
|
||||||
request_prompt = request_prompts[0]
|
request_prompt = request_prompts[0]
|
||||||
|
prompt_text, _, _ = self._get_prompt_components(request_prompt)
|
||||||
|
|
||||||
# Update the sampling params.
|
# Update the sampling params.
|
||||||
sampling_params.max_tokens = self.max_model_len - len(
|
sampling_params.max_tokens = self.max_model_len - len(
|
||||||
|
|||||||
@ -50,6 +50,7 @@ from openai.types.responses.response_reasoning_item import (
|
|||||||
)
|
)
|
||||||
from openai.types.responses.tool import Mcp, Tool
|
from openai.types.responses.tool import Mcp, Tool
|
||||||
from openai_harmony import Message as OpenAIHarmonyMessage
|
from openai_harmony import Message as OpenAIHarmonyMessage
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.engine.protocol import EngineClient
|
from vllm.engine.protocol import EngineClient
|
||||||
@ -94,7 +95,10 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
ResponseUsage,
|
ResponseUsage,
|
||||||
StreamingResponsesResponse,
|
StreamingResponsesResponse,
|
||||||
)
|
)
|
||||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
from vllm.entrypoints.openai.serving_engine import (
|
||||||
|
GenerationError,
|
||||||
|
OpenAIServing,
|
||||||
|
)
|
||||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
from vllm.entrypoints.responses_utils import (
|
from vllm.entrypoints.responses_utils import (
|
||||||
construct_input_messages,
|
construct_input_messages,
|
||||||
@ -541,6 +545,8 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
tokenizer,
|
tokenizer,
|
||||||
request_metadata,
|
request_metadata,
|
||||||
)
|
)
|
||||||
|
except GenerationError as e:
|
||||||
|
return self._convert_generation_error_to_response(e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return self.create_error_response(str(e))
|
return self.create_error_response(str(e))
|
||||||
|
|
||||||
@ -648,6 +654,8 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
status = "incomplete"
|
status = "incomplete"
|
||||||
elif context.finish_reason == "abort":
|
elif context.finish_reason == "abort":
|
||||||
status = "cancelled"
|
status = "cancelled"
|
||||||
|
else:
|
||||||
|
self._raise_if_error(context.finish_reason, request.request_id)
|
||||||
else:
|
else:
|
||||||
status = "incomplete"
|
status = "incomplete"
|
||||||
elif isinstance(context, ParsableContext):
|
elif isinstance(context, ParsableContext):
|
||||||
@ -673,6 +681,9 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
assert len(final_res.outputs) == 1
|
assert len(final_res.outputs) == 1
|
||||||
final_output = final_res.outputs[0]
|
final_output = final_res.outputs[0]
|
||||||
|
|
||||||
|
# finish_reason='error' indicates retryable internal error
|
||||||
|
self._raise_if_error(final_output.finish_reason, request.request_id)
|
||||||
|
|
||||||
output = self._make_response_output_items(request, final_output, tokenizer)
|
output = self._make_response_output_items(request, final_output, tokenizer)
|
||||||
|
|
||||||
if request.enable_response_messages:
|
if request.enable_response_messages:
|
||||||
@ -1066,6 +1077,8 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
async for event in generator:
|
async for event in generator:
|
||||||
event_deque.append(event)
|
event_deque.append(event)
|
||||||
new_event_signal.set() # Signal new event available
|
new_event_signal.set() # Signal new event available
|
||||||
|
except GenerationError as e:
|
||||||
|
response = self._convert_generation_error_to_response(e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Background request failed for %s", request.request_id)
|
logger.exception("Background request failed for %s", request.request_id)
|
||||||
response = self.create_error_response(str(e))
|
response = self.create_error_response(str(e))
|
||||||
@ -1089,6 +1102,8 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
response = await self.responses_full_generator(request, *args, **kwargs)
|
response = await self.responses_full_generator(request, *args, **kwargs)
|
||||||
|
except GenerationError as e:
|
||||||
|
response = self._convert_generation_error_to_response(e)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Background request failed for %s", request.request_id)
|
logger.exception("Background request failed for %s", request.request_id)
|
||||||
response = self.create_error_response(str(e))
|
response = self.create_error_response(str(e))
|
||||||
@ -1227,6 +1242,8 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
continue
|
continue
|
||||||
if ctx.last_output.outputs:
|
if ctx.last_output.outputs:
|
||||||
output = ctx.last_output.outputs[0]
|
output = ctx.last_output.outputs[0]
|
||||||
|
# finish_reason='error' indicates a retryable error
|
||||||
|
self._raise_if_error(output.finish_reason, request.request_id)
|
||||||
if reasoning_parser:
|
if reasoning_parser:
|
||||||
delta_message = reasoning_parser.extract_reasoning_streaming(
|
delta_message = reasoning_parser.extract_reasoning_streaming(
|
||||||
previous_text=previous_text,
|
previous_text=previous_text,
|
||||||
@ -1522,6 +1539,9 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
async for ctx in result_generator:
|
async for ctx in result_generator:
|
||||||
assert isinstance(ctx, StreamingHarmonyContext)
|
assert isinstance(ctx, StreamingHarmonyContext)
|
||||||
|
|
||||||
|
# finish_reason='error' indicates a retryable error
|
||||||
|
self._raise_if_error(ctx.finish_reason, request.request_id)
|
||||||
|
|
||||||
if ctx.is_expecting_start():
|
if ctx.is_expecting_start():
|
||||||
current_output_index += 1
|
current_output_index += 1
|
||||||
sent_output_item_added = False
|
sent_output_item_added = False
|
||||||
@ -2016,18 +2036,25 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async for event_data in processer(
|
try:
|
||||||
request,
|
async for event_data in processer(
|
||||||
sampling_params,
|
request,
|
||||||
result_generator,
|
sampling_params,
|
||||||
context,
|
result_generator,
|
||||||
model_name,
|
context,
|
||||||
tokenizer,
|
model_name,
|
||||||
request_metadata,
|
tokenizer,
|
||||||
created_time,
|
request_metadata,
|
||||||
_increment_sequence_number_and_return,
|
created_time,
|
||||||
):
|
_increment_sequence_number_and_return,
|
||||||
yield event_data
|
):
|
||||||
|
yield event_data
|
||||||
|
except GenerationError as e:
|
||||||
|
error_json = self._convert_generation_error_to_streaming_response(e)
|
||||||
|
yield _increment_sequence_number_and_return(
|
||||||
|
TypeAdapter(StreamingResponsesResponse).validate_json(error_json)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
async def empty_async_generator():
|
async def empty_async_generator():
|
||||||
# A hack to trick Python to think this is a generator but
|
# A hack to trick Python to think this is a generator but
|
||||||
|
|||||||
@ -22,6 +22,7 @@ from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
|||||||
from openai.types.responses.tool import Tool
|
from openai.types.responses.tool import Tool
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
|
from vllm.entrypoints.constants import MCP_PREFIX
|
||||||
from vllm.entrypoints.openai.protocol import (
|
from vllm.entrypoints.openai.protocol import (
|
||||||
ChatCompletionMessageParam,
|
ChatCompletionMessageParam,
|
||||||
ResponseInputOutputItem,
|
ResponseInputOutputItem,
|
||||||
@ -44,13 +45,13 @@ def make_response_output_items_from_parsable_context(
|
|||||||
)
|
)
|
||||||
if isinstance(output_messages[-1], ResponseFunctionToolCall):
|
if isinstance(output_messages[-1], ResponseFunctionToolCall):
|
||||||
mcp_message = McpCall(
|
mcp_message = McpCall(
|
||||||
id=f"mcp_{random_uuid()}",
|
id=f"{MCP_PREFIX}{random_uuid()}",
|
||||||
arguments=output_messages[-1].arguments,
|
arguments=output_messages[-1].arguments,
|
||||||
name=output_messages[-1].name,
|
name=output_messages[-1].name,
|
||||||
server_label=output_messages[
|
server_label=output_messages[
|
||||||
-1
|
-1
|
||||||
].name, # TODO: store the server label
|
].name, # TODO: store the server label
|
||||||
type="mcp_call",
|
type=f"{MCP_PREFIX}call",
|
||||||
status="completed",
|
status="completed",
|
||||||
output=message.output,
|
output=message.output,
|
||||||
# TODO: support error output
|
# TODO: support error output
|
||||||
@ -98,12 +99,63 @@ def construct_input_messages(
|
|||||||
if isinstance(request_input, str):
|
if isinstance(request_input, str):
|
||||||
messages.append({"role": "user", "content": request_input})
|
messages.append({"role": "user", "content": request_input})
|
||||||
else:
|
else:
|
||||||
for item in request_input:
|
input_messages = construct_chat_messages_with_tool_call(request_input)
|
||||||
messages.append(construct_chat_message_with_tool_call(item))
|
messages.extend(input_messages)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def construct_chat_message_with_tool_call(
|
def _maybe_combine_reasoning_and_tool_call(
|
||||||
|
item: ResponseInputOutputItem, messages: list[ChatCompletionMessageParam]
|
||||||
|
) -> ChatCompletionMessageParam | None:
|
||||||
|
"""Many models treat MCP calls and reasoning as a single message.
|
||||||
|
This function checks if the last message is a reasoning message and
|
||||||
|
the current message is a tool call"""
|
||||||
|
if not (
|
||||||
|
isinstance(item, ResponseFunctionToolCall) and item.id.startswith(MCP_PREFIX)
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
if len(messages) == 0:
|
||||||
|
return None
|
||||||
|
last_message = messages[-1]
|
||||||
|
if not (
|
||||||
|
last_message.get("role") == "assistant"
|
||||||
|
and last_message.get("reasoning") is not None
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
last_message["tool_calls"] = [
|
||||||
|
ChatCompletionMessageToolCallParam(
|
||||||
|
id=item.call_id,
|
||||||
|
function=FunctionCallTool(
|
||||||
|
name=item.name,
|
||||||
|
arguments=item.arguments,
|
||||||
|
),
|
||||||
|
type="function",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
return last_message
|
||||||
|
|
||||||
|
|
||||||
|
def construct_chat_messages_with_tool_call(
|
||||||
|
input_messages: list[ResponseInputOutputItem],
|
||||||
|
) -> list[ChatCompletionMessageParam]:
|
||||||
|
"""This function wraps _construct_single_message_from_response_item
|
||||||
|
Because some chatMessages come from multiple response items
|
||||||
|
for example a reasoning item and a MCP tool call are two response items
|
||||||
|
but are one chat message
|
||||||
|
"""
|
||||||
|
messages: list[ChatCompletionMessageParam] = []
|
||||||
|
for item in input_messages:
|
||||||
|
maybe_combined_message = _maybe_combine_reasoning_and_tool_call(item, messages)
|
||||||
|
if maybe_combined_message is not None:
|
||||||
|
messages[-1] = maybe_combined_message
|
||||||
|
else:
|
||||||
|
messages.append(_construct_single_message_from_response_item(item))
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def _construct_single_message_from_response_item(
|
||||||
item: ResponseInputOutputItem,
|
item: ResponseInputOutputItem,
|
||||||
) -> ChatCompletionMessageParam:
|
) -> ChatCompletionMessageParam:
|
||||||
if isinstance(item, ResponseFunctionToolCall):
|
if isinstance(item, ResponseFunctionToolCall):
|
||||||
|
|||||||
35
vllm/envs.py
35
vllm/envs.py
@ -72,10 +72,9 @@ if TYPE_CHECKING:
|
|||||||
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
|
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
|
||||||
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
|
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
|
||||||
VLLM_MEDIA_CONNECTOR: str = "http"
|
VLLM_MEDIA_CONNECTOR: str = "http"
|
||||||
VLLM_MM_INPUT_CACHE_GIB: int = 4
|
|
||||||
VLLM_TARGET_DEVICE: str = "cuda"
|
VLLM_TARGET_DEVICE: str = "cuda"
|
||||||
VLLM_MAIN_CUDA_VERSION: str = "12.9"
|
VLLM_MAIN_CUDA_VERSION: str = "12.9"
|
||||||
VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
|
VLLM_FLOAT32_MATMUL_PRECISION: Literal["ieee", "tf32"] = "ieee"
|
||||||
MAX_JOBS: str | None = None
|
MAX_JOBS: str | None = None
|
||||||
NVCC_THREADS: str | None = None
|
NVCC_THREADS: str | None = None
|
||||||
VLLM_USE_PRECOMPILED: bool = False
|
VLLM_USE_PRECOMPILED: bool = False
|
||||||
@ -457,11 +456,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower()
|
"VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower()
|
||||||
or "12.9",
|
or "12.9",
|
||||||
# Controls PyTorch float32 matmul precision mode within vLLM workers.
|
# Controls PyTorch float32 matmul precision mode within vLLM workers.
|
||||||
# Valid options mirror torch.set_float32_matmul_precision
|
# Accepted values:
|
||||||
|
# - "ieee" (default): force full IEEE FP32 matmul precision.
|
||||||
|
# - "tf32": enable TensorFloat32-based fast matmul.
|
||||||
"VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices(
|
"VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices(
|
||||||
"VLLM_FLOAT32_MATMUL_PRECISION",
|
"VLLM_FLOAT32_MATMUL_PRECISION",
|
||||||
"highest",
|
"ieee",
|
||||||
["highest", "high", "medium"],
|
["ieee", "tf32"],
|
||||||
case_sensitive=False,
|
case_sensitive=False,
|
||||||
),
|
),
|
||||||
# Maximum number of compilation jobs to run in parallel.
|
# Maximum number of compilation jobs to run in parallel.
|
||||||
@ -786,9 +787,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# imported at runtime.
|
# imported at runtime.
|
||||||
# If a non-existing backend is used, an AssertionError will be thrown.
|
# If a non-existing backend is used, an AssertionError will be thrown.
|
||||||
"VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"),
|
"VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"),
|
||||||
# [DEPRECATED] Cache size (in GiB per process) for multimodal input cache
|
|
||||||
# Default is 4 GiB per API process + 4 GiB per engine core process
|
|
||||||
"VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),
|
|
||||||
# Path to the XLA persistent cache directory.
|
# Path to the XLA persistent cache directory.
|
||||||
# Only used for XLA devices such as TPUs.
|
# Only used for XLA devices such as TPUs.
|
||||||
"VLLM_XLA_CACHE_PATH": lambda: os.path.expanduser(
|
"VLLM_XLA_CACHE_PATH": lambda: os.path.expanduser(
|
||||||
@ -1580,6 +1578,12 @@ def __getattr__(name: str):
|
|||||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def _is_envs_cache_enabled() -> bool:
|
||||||
|
"""Checked if __getattr__ is wrapped with functools.cache"""
|
||||||
|
global __getattr__
|
||||||
|
return hasattr(__getattr__, "cache_clear")
|
||||||
|
|
||||||
|
|
||||||
def enable_envs_cache() -> None:
|
def enable_envs_cache() -> None:
|
||||||
"""
|
"""
|
||||||
Enables caching of environment variables. This is useful for performance
|
Enables caching of environment variables. This is useful for performance
|
||||||
@ -1590,6 +1594,9 @@ def enable_envs_cache() -> None:
|
|||||||
runtime overhead. This also means that environment variables should NOT
|
runtime overhead. This also means that environment variables should NOT
|
||||||
be updated after the service is initialized.
|
be updated after the service is initialized.
|
||||||
"""
|
"""
|
||||||
|
if _is_envs_cache_enabled():
|
||||||
|
# Avoid wrapping functools.cache multiple times
|
||||||
|
return
|
||||||
# Tag __getattr__ with functools.cache
|
# Tag __getattr__ with functools.cache
|
||||||
global __getattr__
|
global __getattr__
|
||||||
__getattr__ = functools.cache(__getattr__)
|
__getattr__ = functools.cache(__getattr__)
|
||||||
@ -1599,6 +1606,17 @@ def enable_envs_cache() -> None:
|
|||||||
__getattr__(key)
|
__getattr__(key)
|
||||||
|
|
||||||
|
|
||||||
|
def disable_envs_cache() -> None:
|
||||||
|
"""
|
||||||
|
Resets the environment variables cache. It could be used to isolate environments
|
||||||
|
between unit tests.
|
||||||
|
"""
|
||||||
|
global __getattr__
|
||||||
|
# If __getattr__ is wrapped by functions.cache, unwrap the caching layer.
|
||||||
|
if _is_envs_cache_enabled():
|
||||||
|
__getattr__ = __getattr__.__wrapped__
|
||||||
|
|
||||||
|
|
||||||
def __dir__():
|
def __dir__():
|
||||||
return list(environment_variables.keys())
|
return list(environment_variables.keys())
|
||||||
|
|
||||||
@ -1661,7 +1679,6 @@ def compile_factors() -> dict[str, object]:
|
|||||||
"VLLM_MEDIA_CONNECTOR",
|
"VLLM_MEDIA_CONNECTOR",
|
||||||
"VLLM_ASSETS_CACHE",
|
"VLLM_ASSETS_CACHE",
|
||||||
"VLLM_ASSETS_CACHE_MODEL_CLEAN",
|
"VLLM_ASSETS_CACHE_MODEL_CLEAN",
|
||||||
"VLLM_MM_INPUT_CACHE_GIB",
|
|
||||||
"VLLM_WORKER_MULTIPROC_METHOD",
|
"VLLM_WORKER_MULTIPROC_METHOD",
|
||||||
"VLLM_ENABLE_V1_MULTIPROCESSING",
|
"VLLM_ENABLE_V1_MULTIPROCESSING",
|
||||||
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
|
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
|
||||||
|
|||||||
@ -4,7 +4,10 @@
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEConfig,
|
||||||
|
RoutingMethodType,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
|
||||||
FusedMoEMethodBase,
|
FusedMoEMethodBase,
|
||||||
)
|
)
|
||||||
@ -49,6 +52,7 @@ __all__ = [
|
|||||||
"FusedMoEPermuteExpertsUnpermute",
|
"FusedMoEPermuteExpertsUnpermute",
|
||||||
"FusedMoEActivationFormat",
|
"FusedMoEActivationFormat",
|
||||||
"FusedMoEPrepareAndFinalize",
|
"FusedMoEPrepareAndFinalize",
|
||||||
|
"RoutingMethodType",
|
||||||
"SharedFusedMoE",
|
"SharedFusedMoE",
|
||||||
"activation_without_mul",
|
"activation_without_mul",
|
||||||
"override_config",
|
"override_config",
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user