Merge branch 'mlm-full-lora-support' of https://github.com/jeejeelee/vllm into mlm-full-lora-support

This commit is contained in:
bk-201 2025-12-04 16:58:16 +00:00
commit f67ccfae9c
138 changed files with 5654 additions and 3230 deletions

View File

@ -1,46 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import os

# HTML page template for the wheel index. Placeholders are filled with the
# x86_64 and aarch64 wheel filenames, both raw (link text) and URL-escaped
# (href). Fixed: the original closed the heading with malformed "</h1/>".
template = """<!DOCTYPE html>
<html>
<body>
<h1>Links for vLLM</h1>
<a href="../{x86_wheel_html_escaped}">{x86_wheel}</a><br/>
<a href="../{arm_wheel_html_escaped}">{arm_wheel}</a><br/>
</body>
</html>
"""

parser = argparse.ArgumentParser()
parser.add_argument("--wheel", help="The wheel path.", required=True)
args = parser.parse_args()

filename = os.path.basename(args.wheel)

with open("index.html", "w") as f:
    print(f"Generated index.html for {args.wheel}")
    # sync the abi tag with .buildkite/scripts/upload-wheels.sh
    if "x86_64" in filename:
        x86_wheel = filename
        arm_wheel = filename.replace("x86_64", "aarch64").replace(
            "manylinux1", "manylinux2014"
        )
    elif "aarch64" in filename:
        x86_wheel = filename.replace("aarch64", "x86_64").replace(
            "manylinux2014", "manylinux1"
        )
        arm_wheel = filename
    else:
        # Fixed: the original f-string had no placeholder, so the offending
        # filename was never included in the error message.
        raise ValueError(f"Unsupported wheel: {filename}")
    # cloudfront requires escaping the '+' character
    f.write(
        template.format(
            x86_wheel=x86_wheel,
            x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"),
            arm_wheel=arm_wheel,
            arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"),
        )
    )

View File

@ -7,13 +7,14 @@
import argparse
import json
import re
import sys
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any
from urllib.parse import quote
import regex as re
if not sys.version_info >= (3, 12):
raise RuntimeError("This script requires Python 3.12 or higher.")

View File

@ -74,6 +74,7 @@ FROM ${BASE_IMAGE_NAME}
# Define environments
ENV DEBIAN_FRONTEND=noninteractive
ENV SOC_VERSION="ascend910b1"
RUN pip config set global.index-url http://cache-service-vllm.nginx-pypi-cache.svc.cluster.local:${PYPI_CACHE_PORT}/pypi/simple && \
pip config set global.trusted-host cache-service-vllm.nginx-pypi-cache.svc.cluster.local && \

View File

@ -81,7 +81,7 @@ else
alias_arg=""
fi
$PYTHON .buildkite/scripts/generate-nightly-index.py --version "$SUBPATH" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" $alias_arg
$PYTHON pip install regex && .buildkite/scripts/generate-nightly-index.py --version "$SUBPATH" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" $alias_arg
# copy indices to /<commit>/ unconditionally
echo "Uploading indices to $S3_COMMIT_PREFIX"

View File

@ -987,7 +987,8 @@ steps:
commands:
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1
- label: Multi-Modal Models Test (Extended) 1
- label: Multi-Modal Models Test (Extended) 1 # 60min
timeout_in_minutes: 120
mirror_hardwares: [amdexperimental]
agent_pool: mi325_1
# grade: Blocking
@ -1011,7 +1012,8 @@ steps:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
- label: Multi-Modal Models Test (Extended) 3
- label: Multi-Modal Models Test (Extended) 3 # 75min
timeout_in_minutes: 150
mirror_hardwares: [amdexperimental]
agent_pool: mi325_1
# grade: Blocking

View File

@ -387,6 +387,7 @@ steps:
working_dir: "/vllm-workspace/examples"
source_file_dependencies:
- vllm/entrypoints
- vllm/multimodal
- examples/
commands:
- pip install tensorizer # for tensorizer test

View File

@ -137,6 +137,7 @@ Compute Resources:
- Alibaba Cloud
- AMD
- Anyscale
- Arm
- AWS
- Crusoe Cloud
- Databricks

View File

@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Micro benchmark comparing built-in hash(), SHA-256, and xxHash.
This focuses on a single test payload shaped like the prefix-cache hash input:
(32-byte bytes object, 32-int tuple)
Usage:
python benchmarks/hash_micro_benchmark.py --iterations 20000
"""
from __future__ import annotations
import argparse
import random
import statistics
import time
from collections.abc import Callable, Iterable
from vllm.utils.hashing import sha256, xxhash
def _generate_test_data(seed: int) -> tuple[bytes, tuple[int, ...]]:
"""Generate a deterministic test payload."""
random.seed(seed)
bytes_data = bytes(random.getrandbits(8) for _ in range(32))
int_tuple = tuple(random.randint(1, 1_000_000) for _ in range(32))
return (bytes_data, int_tuple)
def _benchmark_func(func: Callable[[tuple], object], data: tuple, iterations: int):
"""Return (avg_seconds, std_seconds) for hashing `data` `iterations` times."""
times: list[float] = []
# Warm-up to avoid first-run noise.
for _ in range(200):
func(data)
for _ in range(iterations):
start = time.perf_counter()
func(data)
end = time.perf_counter()
times.append(end - start)
avg = statistics.mean(times)
std = statistics.stdev(times) if len(times) > 1 else 0.0
return avg, std
def _run_benchmarks(
    benchmarks: Iterable[tuple[str, Callable[[tuple], object]]],
    data: tuple,
    iterations: int,
):
    """Yield (name, avg, std) per benchmark, skipping ones whose deps are missing."""
    for name, func in benchmarks:
        try:
            measured = _benchmark_func(func, data, iterations)
        except ModuleNotFoundError as exc:
            # Optional hash backends (e.g. xxhash) may not be installed.
            print(f"Skipping {name}: {exc}")
            continue
        yield (name, *measured)
def builtin_hash(data: tuple) -> int:
    """Adapter exposing Python's built-in ``hash`` as a benchmark target."""
    result = hash(data)
    return result
def main() -> None:
    """CLI entry point: run each hash benchmark and print a comparison."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--iterations",
        type=int,
        default=10_000,
        help="Number of measured iterations per hash function.",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for test payload."
    )
    args = parser.parse_args()

    payload = _generate_test_data(args.seed)
    candidates = (
        ("SHA256 (pickle)", sha256),
        ("xxHash (pickle)", xxhash),
        ("built-in hash()", builtin_hash),
    )

    banner = "=" * 60
    print(banner)
    print("HASH FUNCTION MICRO BENCHMARK")
    print(banner)
    print("Test data: (32-byte bytes object, 32-int tuple)")
    print(f"Iterations: {args.iterations:,}")
    print(banner)

    results = list(_run_benchmarks(candidates, payload, args.iterations))

    print("\nResults:")
    for name, avg, std in results:
        print(f" {name:16s}: {avg * 1e6:8.2f} ± {std * 1e6:6.2f} μs")

    # Built-in hash() is the baseline for the relative summary.
    baseline = next((r for r in results if r[0] == "built-in hash()"), None)
    if baseline is None:
        print("\nBuilt-in hash() result missing; cannot compute speed ratios.")
        return

    _, baseline_avg, _ = baseline
    print("\n" + banner)
    print("SUMMARY (relative to built-in hash())")
    print(banner)
    for name, avg, _ in results:
        if name == "built-in hash()":
            continue
        speed_ratio = avg / baseline_avg
        print(f"{name} is {speed_ratio:.1f}x slower than built-in hash()")


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,110 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Simple benchmark to compare prefix-cache block hashing algorithms.
Example:
python benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32
"""
from __future__ import annotations
import argparse
import random
import statistics
import sys
import time
from collections.abc import Callable, Iterable, Sequence
from vllm.utils.hashing import get_hash_fn_by_name
from vllm.v1.core.kv_cache_utils import BlockHash, hash_block_tokens, init_none_hash
SUPPORTED_ALGOS = ("sha256", "sha256_cbor", "xxhash", "xxhash_cbor")
def _generate_blocks(
num_blocks: int, block_size: int, vocab_size: int, seed: int
) -> list[list[int]]:
rng = random.Random(seed)
return [
[rng.randrange(vocab_size) for _ in range(block_size)]
for _ in range(num_blocks)
]
def _hash_all_blocks(
    hash_fn: Callable[[object], bytes],
    blocks: Iterable[Sequence[int]],
) -> float:
    """Hash every block, chaining parent hashes; return elapsed seconds."""
    prev: BlockHash | None = None
    t0 = time.perf_counter()
    for tokens in blocks:
        # Each block's hash depends on its parent, mirroring the real
        # prefix-cache chaining in hash_block_tokens.
        prev = hash_block_tokens(hash_fn, prev, tokens, extra_keys=None)
    return time.perf_counter() - t0
def _benchmark(
    hash_algo: str,
    blocks: list[list[int]],
    trials: int,
) -> tuple[float, float, float] | None:
    """Run ``trials`` timed passes for one algorithm.

    Returns (avg_seconds, best_seconds, tokens_per_second computed against
    the fastest trial), or None when the algorithm's optional dependency is
    not installed.
    """
    try:
        fn = get_hash_fn_by_name(hash_algo)
        init_none_hash(fn)
        timings = [_hash_all_blocks(fn, blocks) for _ in range(trials)]
    except ModuleNotFoundError as exc:
        # e.g. xxhash/cbor2 not installed for this algorithm variant.
        print(f"Skipping {hash_algo}: {exc}", file=sys.stderr)
        return None
    avg = statistics.mean(timings)
    best = min(timings)
    # throughput: tokens / second
    tokens_hashed = len(blocks) * len(blocks[0])
    return avg, best, tokens_hashed / best
def main() -> None:
    """CLI entry point: benchmark each requested block-hash algorithm."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--num-blocks", type=int, default=10000, help="Block count.")
    parser.add_argument("--block-size", type=int, default=32, help="Tokens per block.")
    parser.add_argument(
        "--vocab-size", type=int, default=32000, help="Token id range [0, vocab_size)."
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed.")
    parser.add_argument(
        "--trials", type=int, default=5, help="Number of timed trials per algorithm."
    )
    parser.add_argument(
        "--algorithms",
        nargs="+",
        default=SUPPORTED_ALGOS,
        choices=SUPPORTED_ALGOS,
        help="Hash algorithms to benchmark.",
    )
    args = parser.parse_args()

    token_blocks = _generate_blocks(
        args.num_blocks, args.block_size, args.vocab_size, args.seed
    )
    print(
        f"Benchmarking {len(args.algorithms)} algorithms on "
        f"{args.num_blocks} blocks (block size={args.block_size})."
    )
    for algo in args.algorithms:
        outcome = _benchmark(algo, token_blocks, args.trials)
        if outcome is None:
            # Missing optional dependency; _benchmark already logged it.
            continue
        avg, best, throughput = outcome
        print(
            f"{algo:14s} avg: {avg:.6f}s best: {best:.6f}s "
            f"throughput: {throughput / 1e6:.2f}M tokens/s"
        )


if __name__ == "__main__":
    main()

View File

@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from enum import Enum
from itertools import product
from typing import Any
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_per_token_group_quant_fp8_colmajor,
silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
from .utils import ArgPool, Bench, CudaGraphBenchParams
GROUP_SIZE = 128
FLOAT8_T = torch.float8_e4m3fn
def print_timers(timers: list[TMeasurement], cuda_graph_nops: int):
    """Print a torch.utils.benchmark comparison table for ``timers``.

    Each measurement covers ``cuda_graph_nops`` back-to-back invocations,
    so a note reminding the reader to normalize is emitted first.
    """
    note = (
        f"Note : The timings reported above is for {cuda_graph_nops} "
        "consecutive invocations of the benchmarking functions. "
        f"Please divide by {cuda_graph_nops} for single invocation "
        "timings."
    )
    print(note)
    TBenchmark.Compare(timers).print()
class ImplType(Enum):
    """Selects which implementation a benchmark run exercises."""

    # Fused kernel: SiLU-mul activation + per-token-group FP8 quant in one pass.
    SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1
    # Unfused path: silu_and_mul op followed by the standalone quant kernel.
    REFERENCE = 2

    def get_impl(self):
        """Return the callable implementing this variant."""
        if self == ImplType.REFERENCE:
            return reference
        if self == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return silu_mul_per_token_group_quant_fp8_colmajor
        raise ValueError(f"Unrecognized ImplType {self}")
@dataclass
class BenchmarkTensors:
    """Input/output CUDA buffers shared by one (T, N) benchmark configuration."""

    # Activation input of shape (T, N), bfloat16.
    input: torch.Tensor
    # Output buffer for the fused silu-mul+quant kernel, shape (T, N // 2), fp8.
    output: torch.Tensor
    # Reference act output tensor
    ref_act_out: torch.Tensor
    # Quantized output buffer for the reference path, shape (T, N // 2), fp8.
    ref_quant_out: torch.Tensor

    @staticmethod
    def make(T: int, N: int) -> "BenchmarkTensors":
        """Allocate CUDA tensors for T tokens and hidden size N."""
        # Group quantization needs whole groups; silu-mul halves N, hence
        # the extra factor of 2 in the second check.
        assert T % GROUP_SIZE == 0
        assert N % (GROUP_SIZE * 2) == 0
        input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")
        # silu_mul_per_token_group_quant_fp8_colmajor output.
        output = torch.rand((T, N // 2), dtype=torch.bfloat16, device="cuda").to(
            FLOAT8_T
        )
        # reference output.
        ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
        ref_quant_out = torch.empty(
            (T, N // 2), dtype=torch.bfloat16, device="cuda"
        ).to(FLOAT8_T)
        return BenchmarkTensors(
            input=input,
            output=output,
            ref_act_out=ref_act_out,
            ref_quant_out=ref_quant_out,
        )

    @property
    def T(self):
        """Number of tokens (rows) in the input."""
        return self.input.size(0)

    @property
    def N(self):
        """Hidden size (columns) of the input, before silu-mul halving."""
        return self.input.size(1)

    def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]:
        """Build the kwargs dict matching the chosen implementation's signature."""
        if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
            return {
                "input": self.input,
                "output": self.output,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        elif impl_type == ImplType.REFERENCE:
            return {
                "input": self.input,
                "act_out": self.ref_act_out,
                "quant_out": self.ref_quant_out,
                "use_ue8m0": is_deep_gemm_e8m0_used(),
            }
        raise ValueError(f"Unrecognized impl_type {impl_type}")
def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool):
    """Quantize `x` into `quant_out` with the standalone Triton kernel.

    Reference triton quant kernel from
    vllm.model_executor.layers.quantization.utils.fp8_utils.

    Returns (quantized tensor, per-group scales in column-major layout).
    """
    assert quant_out.size() == x.size()
    # Allocate the scale tensor column-major format.
    shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
    x_q = quant_out
    x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
    # One Triton program instance per token group.
    M = x.numel() // GROUP_SIZE
    N = GROUP_SIZE
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    # Clamp quantized values to the representable fp8 range.
    finfo = torch.finfo(FLOAT8_T)
    fp8_min = finfo.min
    fp8_max = finfo.max
    _per_token_group_quant_fp8_colmajor[(M,)](
        x,
        x_q,
        x_s,
        GROUP_SIZE,
        x.shape[1],
        x.stride(0),
        x_s.stride(1),
        eps=1e-10,
        fp8_min=fp8_min,
        fp8_max=fp8_max,
        use_ue8m0=use_ue8m0,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return x_q, x_s
def reference(
    input: torch.Tensor,
    act_out: torch.Tensor,
    quant_out: torch.Tensor,
    use_ue8m0: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Unfused reference path: silu_and_mul custom op, then the quant kernel.

    Returns (quantized tensor, per-group scales), the same contract as the
    fused kernel being benchmarked against it.
    """
    torch.ops._C.silu_and_mul(act_out, input)
    return reference_quant(act_out, quant_out, use_ue8m0)
def bench_impl(
    bench_tensors: list[BenchmarkTensors], impl_type: ImplType
) -> TMeasurement:
    """Benchmark one implementation over a pool of tensor sets.

    All entries in `bench_tensors` share the same (T, N); the pool lets the
    CUDA-graph benchmark rotate arguments between replays.
    """
    T = bench_tensors[0].T
    N = bench_tensors[0].N
    arg_pool_size = len(bench_tensors)
    kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors]
    # warmup
    for kwargs in kwargs_list:
        impl_type.get_impl()(**kwargs)
    torch.cuda.synchronize()
    # Merge into a single kwargs and qualify arguments as ArgPool
    kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
    for _kwargs in kwargs_list:
        for k, v in _kwargs.items():
            kwargs[k].values.append(v)
    cuda_graph_params = None
    cuda_graph_params = CudaGraphBenchParams(arg_pool_size)
    timer = None
    with Bench(
        cuda_graph_params,
        "silu-mul-quant",
        f"num_tokens={T}, N={N}",
        impl_type.name,
        impl_type.get_impl(),
        **kwargs,
    ) as bench:
        timer = bench.run()
    return timer
def test_correctness(T: int, N: int):
    """Check the fused kernel against the reference for one (T, N) shape."""
    print(f"Testing num_tokens={T}, N={N} ...")
    tensors = BenchmarkTensors.make(T, N)

    def run_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]:
        return impl.get_impl()(**tensors.make_impl_kwargs(impl))

    # Reference (unfused) output.
    ref_q, ref_s = run_impl(ImplType.REFERENCE)
    # Fused-kernel output under test.
    got_q, got_s = run_impl(ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR)

    # Quantized outputs are upcast to float32 before comparison.
    torch.testing.assert_close(ref_q.to(torch.float32), got_q.to(torch.float32))
    torch.testing.assert_close(ref_s, got_s)
def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]:
    """Correctness-check then benchmark every (T, N) pair; return all timers."""
    timers = []
    for N, T in product(Ns, Ts):
        # Fail fast if the fused kernel diverges from the reference.
        test_correctness(T, N)
        bench_tensors: list[BenchmarkTensors] = [
            BenchmarkTensors.make(T, N) for _ in range(arg_pool_size)
        ]
        silu_mul_quant_timer = bench_impl(
            bench_tensors, ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
        )
        timers.append(silu_mul_quant_timer)
        reference_timer = bench_impl(bench_tensors, ImplType.REFERENCE)
        timers.append(reference_timer)
        # Per-shape comparison, printed as soon as this (T, N) finishes.
        print_timers(
            [silu_mul_quant_timer, reference_timer], cuda_graph_nops=arg_pool_size
        )
    # Summary across all shapes benchmarked above.
    print_timers(timers, cuda_graph_nops=arg_pool_size)
    return timers


if __name__ == "__main__":
    # Token counts: fine-grained sub-2K sizes plus large 2K multiples.
    T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)]
    N = [2048, 4096, 8192]
    print(f"T = {T}, N = {N}")
    run(T, N, arg_pool_size=8)

View File

@ -150,6 +150,97 @@ ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0 12.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
#################### BASE BUILD IMAGE ####################
#################### CSRC BUILD IMAGE ####################
FROM base AS csrc-build
ARG TARGETPLATFORM
ARG PIP_INDEX_URL UV_INDEX_URL
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
ARG PYTORCH_CUDA_INDEX_BASE_URL
# install build dependencies
COPY requirements/build.txt requirements/build.txt
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694
ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"
# Use copy mode to avoid hardlink failures with Docker cache mounts
ENV UV_LINK_MODE=copy
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
WORKDIR /workspace
COPY pyproject.toml setup.py CMakeLists.txt ./
COPY cmake cmake/
COPY csrc csrc/
COPY vllm/envs.py vllm/envs.py
COPY vllm/__init__.py vllm/__init__.py
# max jobs used by Ninja to build extensions
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
# number of threads used by nvcc
ARG nvcc_threads=8
ENV NVCC_THREADS=$nvcc_threads
ARG USE_SCCACHE
ARG SCCACHE_DOWNLOAD_URL=https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz
ARG SCCACHE_ENDPOINT
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
ARG SCCACHE_REGION_NAME=us-west-2
ARG SCCACHE_S3_NO_CREDENTIALS=0
# Flag to control whether to use pre-built vLLM wheels
ARG VLLM_USE_PRECOMPILED=""
ARG VLLM_MERGE_BASE_COMMIT=""
ARG VLLM_MAIN_CUDA_VERSION=""
# Use dummy version for csrc-build wheel (only .so files are extracted, version doesn't matter)
ENV SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0+csrc.build"
# if USE_SCCACHE is set, use sccache to speed up compilation
RUN --mount=type=cache,target=/root/.cache/uv \
if [ "$USE_SCCACHE" = "1" ]; then \
echo "Installing sccache..." \
&& curl -L -o sccache.tar.gz ${SCCACHE_DOWNLOAD_URL} \
&& tar -xzf sccache.tar.gz \
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
&& if [ ! -z ${SCCACHE_ENDPOINT} ] ; then export SCCACHE_ENDPOINT=${SCCACHE_ENDPOINT} ; fi \
&& export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \
&& export SCCACHE_REGION=${SCCACHE_REGION_NAME} \
&& export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \
&& export SCCACHE_IDLE_TIMEOUT=0 \
&& export CMAKE_BUILD_TYPE=Release \
&& export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \
&& export VLLM_PRECOMPILED_WHEEL_COMMIT="${VLLM_MERGE_BASE_COMMIT}" \
&& export VLLM_MAIN_CUDA_VERSION="${VLLM_MAIN_CUDA_VERSION}" \
&& export VLLM_DOCKER_BUILD_CONTEXT=1 \
&& sccache --show-stats \
&& python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \
&& sccache --show-stats; \
fi
ARG vllm_target_device="cuda"
ENV VLLM_TARGET_DEVICE=${vllm_target_device}
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \
if [ "$USE_SCCACHE" != "1" ]; then \
# Clean any existing CMake artifacts
rm -rf .deps && \
mkdir -p .deps && \
export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" && \
export VLLM_PRECOMPILED_WHEEL_COMMIT="${VLLM_MERGE_BASE_COMMIT}" && \
export VLLM_DOCKER_BUILD_CONTEXT=1 && \
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
fi
#################### CSRC BUILD IMAGE ####################
#################### WHEEL BUILD IMAGE ####################
FROM base AS build
ARG TARGETPLATFORM
@ -172,66 +263,28 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
WORKDIR /workspace
COPY --from=csrc-build /workspace/dist /precompiled-wheels
COPY . .
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi
# max jobs used by Ninja to build extensions
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
# number of threads used by nvcc
ARG nvcc_threads=8
ENV NVCC_THREADS=$nvcc_threads
ARG USE_SCCACHE
ARG SCCACHE_DOWNLOAD_URL=https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz
ARG SCCACHE_ENDPOINT
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
ARG SCCACHE_REGION_NAME=us-west-2
ARG SCCACHE_S3_NO_CREDENTIALS=0
# Flag to control whether to use pre-built vLLM wheels
ARG VLLM_USE_PRECOMPILED=""
ARG VLLM_MAIN_CUDA_VERSION=""
# if USE_SCCACHE is set, use sccache to speed up compilation
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=.git,target=.git \
if [ "$USE_SCCACHE" = "1" ]; then \
echo "Installing sccache..." \
&& curl -L -o sccache.tar.gz ${SCCACHE_DOWNLOAD_URL} \
&& tar -xzf sccache.tar.gz \
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
&& if [ ! -z ${SCCACHE_ENDPOINT} ] ; then export SCCACHE_ENDPOINT=${SCCACHE_ENDPOINT} ; fi \
&& export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \
&& export SCCACHE_REGION=${SCCACHE_REGION_NAME} \
&& export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \
&& export SCCACHE_IDLE_TIMEOUT=0 \
&& export CMAKE_BUILD_TYPE=Release \
&& export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \
&& export VLLM_MAIN_CUDA_VERSION="${VLLM_MAIN_CUDA_VERSION}" \
&& export VLLM_DOCKER_BUILD_CONTEXT=1 \
&& sccache --show-stats \
&& python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \
&& sccache --show-stats; \
fi
ARG vllm_target_device="cuda"
ENV VLLM_TARGET_DEVICE=${vllm_target_device}
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=.git,target=.git \
if [ "$USE_SCCACHE" != "1" ]; then \
# Clean any existing CMake artifacts
rm -rf .deps && \
mkdir -p .deps && \
export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" && \
export VLLM_DOCKER_BUILD_CONTEXT=1 && \
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \
fi
# Skip adding +precompiled suffix to version (preserves git-derived version)
ENV VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX=1
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=.git,target=.git \
if [ "${vllm_target_device}" = "cuda" ]; then \
export VLLM_PRECOMPILED_WHEEL_LOCATION=$(ls /precompiled-wheels/*.whl); \
fi && \
python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38
# Install DeepGEMM from source
ARG DEEPGEMM_GIT_REF
@ -527,7 +580,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
else \
BITSANDBYTES_VERSION="0.46.1"; \
fi; \
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.0'
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.3'
ENV VLLM_USAGE_SOURCE production-docker-image

Binary file not shown.

Before

Width:  |  Height:  |  Size: 146 KiB

After

Width:  |  Height:  |  Size: 174 KiB

View File

@ -670,6 +670,35 @@ vllm bench serve \
</details>
### 🧪 Hashing Benchmarks
<details class="admonition abstract" markdown="1">
<summary>Show more</summary>
Two helper scripts live in `benchmarks/` to compare hashing options used by prefix caching and related utilities. They are standalone (no server required) and help choose a hash algorithm before enabling prefix caching in production.
- `benchmarks/benchmark_hash.py`: Micro-benchmark that measures per-call latency of three implementations on a representative `(bytes, tuple[int])` payload.
```bash
python benchmarks/benchmark_hash.py --iterations 20000 --seed 42
```
- `benchmarks/benchmark_prefix_block_hash.py`: End-to-end block hashing benchmark that runs the full prefix-cache hash pipeline (`hash_block_tokens`) across many fake blocks and reports throughput.
```bash
python benchmarks/benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32 --trials 5
```
Supported algorithms: `sha256`, `sha256_cbor`, `xxhash`, `xxhash_cbor`. Install optional deps to exercise all variants:
```bash
uv pip install xxhash cbor2
```
If an algorithm's dependency is missing, the script will skip that algorithm and continue.
</details>
### ⚡ Request Prioritization Benchmark
<details class="admonition abstract" markdown="1">

View File

@ -18,6 +18,7 @@ Compute Resources:
- Alibaba Cloud
- AMD
- Anyscale
- Arm
- AWS
- Crusoe Cloud
- Databricks

View File

@ -57,15 +57,15 @@ vLLM also provides [a reference example](../../examples/online_serving/prometheu
The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important:
- `vllm:e2e_request_latency_seconds_bucket` - End to end request latency measured in seconds.
- `vllm:prompt_tokens_total` - Prompt tokens.
- `vllm:generation_tokens_total` - Generation tokens.
- `vllm:prompt_tokens` - Prompt tokens.
- `vllm:generation_tokens` - Generation tokens.
- `vllm:time_per_output_token_seconds` - Inter-token latency (Time Per Output Token, TPOT) in seconds.
- `vllm:time_to_first_token_seconds` - Time to First Token (TTFT) latency in seconds.
- `vllm:num_requests_running` (also, `_swapped` and `_waiting`) - Number of requests in the RUNNING, WAITING, and SWAPPED states.
- `vllm:gpu_cache_usage_perc` - Percentage of used cache blocks by vLLM.
- `vllm:request_prompt_tokens` - Request prompt length.
- `vllm:request_generation_tokens` - Request generation length.
- `vllm:request_success_total` - Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.
- `vllm:request_success` - Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.
- `vllm:request_queue_time_seconds` - Queue time.
- `vllm:request_prefill_time_seconds` - Requests prefill time.
- `vllm:request_decode_time_seconds` - Requests decode time.
@ -571,9 +571,9 @@ model and then validate those tokens with the larger model.
- `vllm:spec_decode_draft_acceptance_rate` (Gauge)
- `vllm:spec_decode_efficiency` (Gauge)
- `vllm:spec_decode_num_accepted_tokens_total` (Counter)
- `vllm:spec_decode_num_draft_tokens_total` (Counter)
- `vllm:spec_decode_num_emitted_tokens_total` (Counter)
- `vllm:spec_decode_num_accepted_tokens` (Counter)
- `vllm:spec_decode_num_draft_tokens` (Counter)
- `vllm:spec_decode_num_emitted_tokens` (Counter)
There is a PR under review (<https://github.com/vllm-project/vllm/pull/12193>) to add "prompt lookup (ngram)"
speculative decoding to v1. Other techniques will follow. We should

View File

@ -90,7 +90,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
| cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],</br>[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| deep gemm+triton<sup>2</sup> | standard,</br>batched | all<sup>1</sup> | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],</br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
| marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
@ -114,5 +113,5 @@ The following table shows "families" of modular kernels that are intended to wor
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
|---------|-----------------------------------------|----------------------------------------------|
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |

View File

@ -54,7 +54,7 @@ th:not(:first-child) {
| beam-search | ✅ | ✅ | ✅ | [](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | |
| [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
\* Chunked prefill and prefix caching are only applicable to last-token pooling.
\* Chunked prefill and prefix caching are only applicable to last-token or all pooling with causal attention.
<sup>^</sup> LoRA is only applicable to the language backbone of multimodal models.
### Feature x Hardware

View File

@ -0,0 +1,58 @@
# MooncakeConnector Usage Guide
## About Mooncake
Mooncake aims to enhance the inference efficiency of large language models (LLMs), especially in slow object storage environments, by constructing a multi-level caching pool on high-speed interconnected DRAM/SSD resources. Compared to traditional caching systems, Mooncake utilizes (GPUDirect) RDMA technology to transfer data directly in a zero-copy manner, while maximizing the use of multi-NIC resources on a single machine.
For more details about Mooncake, please refer to [Mooncake project](https://github.com/kvcache-ai/Mooncake) and [Mooncake documents](https://kvcache-ai.github.io/Mooncake/).
## Prerequisites
### Installation
Install mooncake through pip: `uv pip install mooncake-transfer-engine`.
Refer to [Mooncake official repository](https://github.com/kvcache-ai/Mooncake) for more installation instructions
## Usage
### Prefiller Node (192.168.0.2)
```bash
vllm serve Qwen/Qwen2.5-7B-Instruct --port 8010 --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_producer"}'
```
### Decoder Node (192.168.0.3)
```bash
vllm serve Qwen/Qwen2.5-7B-Instruct --port 8020 --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_consumer"}'
```
### Proxy
```bash
python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --prefiller-host 192.168.0.2 --prefiller-port 8010 --decoder-host 192.168.0.3 --decoder-port 8020
```
> NOTE: The Mooncake Connector currently uses the proxy from nixl_integration. This will be replaced with a self-developed proxy in the future.
Now you can send requests to the proxy server through port 8000.
## Environment Variables
- `VLLM_MOONCAKE_BOOTSTRAP_PORT`: Port for Mooncake bootstrap server
- Default: 8998
- Required only for prefiller instances
- Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine
- For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank
- Used for the decoder notifying the prefiller
- `VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller's KV cache for a particular request. (Optional)
- Default: 480
- If a request is aborted and the decoder has not yet notified the prefiller, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely.
## KV Role Options
- **kv_producer**: For prefiller instances that generate KV caches
- **kv_consumer**: For decoder instances that consume KV caches from prefiller
- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined.

View File

@ -795,14 +795,12 @@ The following example demonstrates how to pass image embeddings to the OpenAI se
??? code
```python
from vllm.utils.serial_utils import tensor2base64
image_embedding = torch.load(...)
grid_thw = torch.load(...) # Required by Qwen/Qwen2-VL-2B-Instruct
buffer = io.BytesIO()
torch.save(image_embedding, buffer)
buffer.seek(0)
binary_data = buffer.read()
base64_image_embedding = base64.b64encode(binary_data).decode('utf-8')
base64_image_embedding = tensor2base64(image_embedding)
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")

View File

@ -4,9 +4,6 @@ vLLM has experimental support for macOS with Apple Silicon. For now, users must
Currently the CPU implementation for macOS supports FP32 and FP16 datatypes.
!!! warning
There are no pre-built wheels or images for this device, so you must build vLLM from source.
# --8<-- [end:installation]
# --8<-- [start:requirements]
@ -20,6 +17,8 @@ Currently the CPU implementation for macOS supports FP32 and FP16 datatypes.
# --8<-- [end:set-up-using-python]
# --8<-- [start:pre-built-wheels]
Currently, there are no pre-built Apple silicon CPU wheels.
# --8<-- [end:pre-built-wheels]
# --8<-- [start:build-wheel-from-source]
@ -78,6 +77,8 @@ uv pip install -e .
# --8<-- [end:build-wheel-from-source]
# --8<-- [start:pre-built-images]
Currently, there are no pre-built Arm silicon CPU images.
# --8<-- [end:pre-built-images]
# --8<-- [start:build-image-from-source]

View File

@ -1,11 +1,6 @@
# --8<-- [start:installation]
vLLM has been adapted to work on ARM64 CPUs with NEON support, leveraging the CPU backend initially developed for the x86 platform.
ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes.
!!! warning
There are no pre-built wheels or images for this device, so you must build vLLM from source.
vLLM offers basic model inferencing and serving on the Arm CPU platform, with support for NEON and the FP32, FP16 and BF16 data types.
# --8<-- [end:installation]
# --8<-- [start:requirements]
@ -20,6 +15,23 @@ ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes.
# --8<-- [end:set-up-using-python]
# --8<-- [start:pre-built-wheels]
Pre-built vLLM wheels for Arm are available since version 0.11.2. These wheels contain pre-compiled C++ binaries.
Please replace `<version>` in the commands below with a specific version string (e.g., `0.11.2`).
```bash
uv pip install --pre vllm==<version>+cpu --extra-index-url https://wheels.vllm.ai/<version>%2Bcpu/
```
??? console "pip"
```bash
pip install --pre vllm==<version>+cpu --extra-index-url https://wheels.vllm.ai/<version>%2Bcpu/
```
The `uv` approach works for vLLM `v0.6.6` and later. A unique feature of `uv` is that packages in `--extra-index-url` have [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). If the latest public release is `v0.6.6.post1`, `uv`'s behavior allows installing a commit before `v0.6.6.post1` by specifying the `--extra-index-url`. In contrast, `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version.
!!! note
    Nightly wheels (e.g. for bisecting a behavior change or a performance regression) are currently unsupported for this architecture.
# --8<-- [end:pre-built-wheels]
# --8<-- [start:build-wheel-from-source]
@ -69,6 +81,8 @@ Testing has been conducted on AWS Graviton3 instances for compatibility.
# --8<-- [end:build-wheel-from-source]
# --8<-- [start:pre-built-images]
Currently, there are no pre-built Arm CPU images.
# --8<-- [end:pre-built-images]
# --8<-- [start:build-image-from-source]
```bash

View File

@ -46,11 +46,25 @@ vLLM is a Python library that supports the following CPU variants. Select your C
### Pre-built wheels
Please refer to the instructions for [pre-built wheels on GPU](./gpu.md#pre-built-wheels).
When specifying the index URL, please make sure to use the `cpu` variant subdirectory.
For example, the nightly build index is: `https://wheels.vllm.ai/nightly/cpu/`.
=== "Intel/AMD x86"
--8<-- "docs/getting_started/installation/cpu.x86.inc.md:pre-built-wheels"
=== "ARM AArch64"
--8<-- "docs/getting_started/installation/cpu.arm.inc.md:pre-built-wheels"
=== "Apple silicon"
--8<-- "docs/getting_started/installation/cpu.apple.inc.md:pre-built-wheels"
=== "IBM Z (S390X)"
--8<-- "docs/getting_started/installation/cpu.s390x.inc.md:pre-built-wheels"
### Build wheel from source
#### Set up using Python-only build (without compilation) {#python-only-build}
@ -87,6 +101,18 @@ VLLM_USE_PRECOMPILED=1 VLLM_PRECOMPILED_WHEEL_VARIANT=cpu VLLM_TARGET_DEVICE=cpu
--8<-- "docs/getting_started/installation/cpu.x86.inc.md:pre-built-images"
=== "ARM AArch64"
--8<-- "docs/getting_started/installation/cpu.arm.inc.md:pre-built-images"
=== "Apple silicon"
--8<-- "docs/getting_started/installation/cpu.apple.inc.md:pre-built-images"
=== "IBM Z (S390X)"
--8<-- "docs/getting_started/installation/cpu.s390x.inc.md:pre-built-images"
### Build image from source
=== "Intel/AMD x86"

View File

@ -4,9 +4,6 @@ vLLM has experimental support for s390x architecture on IBM Z platform. For now,
Currently, the CPU implementation for s390x architecture supports FP32 datatype only.
!!! warning
There are no pre-built wheels or images for this device, so you must build vLLM from source.
# --8<-- [end:installation]
# --8<-- [start:requirements]
@ -21,6 +18,8 @@ Currently, the CPU implementation for s390x architecture supports FP32 datatype
# --8<-- [end:set-up-using-python]
# --8<-- [start:pre-built-wheels]
Currently, there are no pre-built IBM Z CPU wheels.
# --8<-- [end:pre-built-wheels]
# --8<-- [start:build-wheel-from-source]
@ -69,6 +68,8 @@ Execute the following commands to build and install vLLM from source.
# --8<-- [end:build-wheel-from-source]
# --8<-- [start:pre-built-images]
Currently, there are no pre-built IBM Z CPU images.
# --8<-- [end:pre-built-images]
# --8<-- [start:build-image-from-source]

View File

@ -17,6 +17,8 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data
# --8<-- [end:set-up-using-python]
# --8<-- [start:pre-built-wheels]
Currently, there are no pre-built x86 CPU wheels.
# --8<-- [end:pre-built-wheels]
# --8<-- [start:build-wheel-from-source]

View File

@ -5,9 +5,6 @@ vLLM supports AMD GPUs with ROCm 6.3 or above, and torch 2.8.0 and above.
!!! tip
[Docker](#set-up-using-docker) is the recommended way to use vLLM on ROCm.
!!! warning
There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source.
# --8<-- [end:installation]
# --8<-- [start:requirements]

View File

@ -2,9 +2,6 @@
vLLM initially supports basic model inference and serving on Intel GPU platform.
!!! warning
There are no pre-built wheels for this device, so you need build vLLM from source. Or you can use pre-built images which are based on vLLM released versions.
# --8<-- [end:installation]
# --8<-- [start:requirements]

View File

@ -711,7 +711,6 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ |
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ |
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ |
| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ |
| `PixtralForConditionalGeneration` | Ministral 3 (Mistral format), Mistral 3 (Mistral format), Mistral Large 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Mistral-Large-3-675B-Instruct-2512`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ |
| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ |
| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ |

View File

@ -23,31 +23,23 @@ def create_test_prompts(
# this is an example of using quantization without LoRA
(
"My name is",
SamplingParams(
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
None,
),
# the next three examples use quantization with LoRA
(
"my name is",
SamplingParams(
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
LoRARequest("lora-test-1", 1, lora_path),
),
(
"The capital of USA is",
SamplingParams(
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
LoRARequest("lora-test-2", 1, lora_path),
),
(
"The capital of France is",
SamplingParams(
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
LoRARequest("lora-test-3", 1, lora_path),
),
]

View File

@ -27,9 +27,7 @@ def create_test_prompts(
return [
(
"A robot may not injure a human being",
SamplingParams(
temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
None,
),
(
@ -41,22 +39,12 @@ def create_test_prompts(
),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(
temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128,
),
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
LoRARequest("sql-lora", 1, lora_path),
),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(
temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128,
),
SamplingParams(temperature=0.0, logprobs=1, max_tokens=128),
LoRARequest("sql-lora2", 2, lora_path),
),
]

View File

@ -28,13 +28,11 @@ Dependencies:
- openai
"""
import base64
import io
import torch
import transformers
from openai import OpenAI
from vllm.utils.serial_utils import tensor2base64
def main():
client = OpenAI(
@ -58,11 +56,7 @@ def main():
prompt_embeds = embedding_layer(token_ids).squeeze(0)
# Prompt embeddings
buffer = io.BytesIO()
torch.save(prompt_embeds, buffer)
buffer.seek(0)
binary_data = buffer.read()
encoded_embeds = base64.b64encode(binary_data).decode("utf-8")
encoded_embeds = tensor2base64(prompt_embeds)
completion = client.completions.create(
model=model_name,

View File

@ -150,7 +150,8 @@ def run_siglip(client: OpenAI, model: str):
Start the server using:
vllm serve google/siglip-base-patch16-224 \
--runner pooling
--runner pooling \
--chat-template template_basic.jinja
"""
response = create_chat_embeddings(

View File

@ -46,6 +46,7 @@ scipy # Required for phi-4-multimodal-instruct
ninja # Required for xgrammar, rocm, tpu, xpu
pybase64 # fast base64 implementation
cbor2 # Required for cross-language serialization of hashable objects
ijson # Required for mistral streaming tool parser
setproctitle # Used to set process names for better debugging and monitoring
openai-harmony >= 0.0.3 # Required for gpt-oss
anthropic == 0.71.0

View File

@ -3,7 +3,6 @@ ninja
packaging>=24.2
setuptools>=77.0.3,<81.0.0
setuptools-scm>=8
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.9.1+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
torch==2.9.1; platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "aarch64"
scons; platform_machine == "aarch64" # needed to build Arm Compute Library (ACL)

View File

@ -4,7 +4,6 @@
numba == 0.61.2; platform_machine != "s390x" # Required for N-gram speculative decoding
# Dependencies for CPUs
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.9.1+cpu; platform_machine == "x86_64" or platform_machine == "s390x"
torch==2.9.1; platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "aarch64"

View File

@ -42,6 +42,6 @@ tritonclient==2.51.0
numba == 0.61.2 # Required for N-gram speculative decoding
numpy
runai-model-streamer[s3,gcs]==0.15.0
runai-model-streamer[s3,gcs]==0.15.3
fastsafetensors>=0.1.10
pydantic>=2.12 # 2.11 leads to error on python 3.13

View File

@ -12,7 +12,7 @@ tensorizer==2.10.1
packaging>=24.2
setuptools>=77.0.3,<80.0.0
setuptools-scm>=8
runai-model-streamer[s3,gcs]==0.15.0
runai-model-streamer[s3,gcs]==0.15.3
conch-triton-kernels==1.2.1
timm>=1.0.17
fastsafetensors @ git+https://github.com/foundation-model-stack/fastsafetensors.git@d6f998a03432b2452f8de2bb5cefb5af9795d459

View File

@ -51,7 +51,7 @@ tritonclient==2.51.0
arctic-inference == 0.1.1 # Required for suffix decoding test
numba == 0.61.2 # Required for N-gram speculative decoding
numpy
runai-model-streamer[s3,gcs]==0.15.0
runai-model-streamer[s3,gcs]==0.15.3
fastsafetensors>=0.1.10
pydantic>=2.12 # 2.11 leads to error on python 3.13
decord==0.6.0

View File

@ -965,11 +965,11 @@ rsa==4.9.1
# via google-auth
rtree==1.4.0
# via torchgeo
runai-model-streamer==0.15.0
runai-model-streamer==0.15.3
# via -r requirements/test.in
runai-model-streamer-gcs==0.15.0
runai-model-streamer-gcs==0.15.3
# via runai-model-streamer
runai-model-streamer-s3==0.15.0
runai-model-streamer-s3==0.15.3
# via runai-model-streamer
s3transfer==0.10.3
# via boto3

View File

@ -346,10 +346,13 @@ class precompiled_wheel_utils:
The order of preference is:
1. user-specified wheel location (can be either local or remote, via
VLLM_PRECOMPILED_WHEEL_LOCATION)
2. user-specified variant from nightly repo (current main commit via
VLLM_PRECOMPILED_WHEEL_VARIANT)
2. user-specified variant (VLLM_PRECOMPILED_WHEEL_VARIANT) from nightly repo
3. the variant corresponding to VLLM_MAIN_CUDA_VERSION from nightly repo
4. the default variant from nightly repo (current main commit)
4. the default variant from nightly repo
If downloading from the nightly repo, the commit can be specified via
VLLM_PRECOMPILED_WHEEL_COMMIT; otherwise, the head commit in the main branch
is used.
"""
wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None)
if wheel_location is not None:
@ -362,10 +365,13 @@ class precompiled_wheel_utils:
# try to fetch the wheel metadata from the nightly wheel repo
main_variant = "cu" + envs.VLLM_MAIN_CUDA_VERSION.replace(".", "")
variant = os.getenv("VLLM_PRECOMPILED_WHEEL_VARIANT", main_variant)
commit = os.getenv(
"VLLM_PRECOMPILED_WHEEL_COMMIT",
precompiled_wheel_utils.get_base_commit_in_main_branch(),
)
commit = os.getenv("VLLM_PRECOMPILED_WHEEL_COMMIT", "").lower()
if not commit or len(commit) != 40:
print(
f"VLLM_PRECOMPILED_WHEEL_COMMIT not valid: {commit}"
", trying to fetch base commit in main branch"
)
commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
print(f"Using precompiled wheel commit {commit} with variant {variant}")
try_default = False
wheels, repo_url, download_filename = None, None, None
@ -461,14 +467,22 @@ class precompiled_wheel_utils:
"vllm/cumem_allocator.abi3.so",
]
compiled_regex = re.compile(
flash_attn_regex = re.compile(
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
)
triton_kernels_regex = re.compile(
r"vllm/third_party/triton_kernels/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
)
file_members = list(
filter(lambda x: x.filename in files_to_copy, wheel.filelist)
)
file_members += list(
filter(lambda x: compiled_regex.match(x.filename), wheel.filelist)
filter(lambda x: flash_attn_regex.match(x.filename), wheel.filelist)
)
file_members += list(
filter(
lambda x: triton_kernels_regex.match(x.filename), wheel.filelist
)
)
for file in file_members:
@ -494,10 +508,6 @@ class precompiled_wheel_utils:
@staticmethod
def get_base_commit_in_main_branch() -> str:
# Force to use the nightly wheel. This is mainly used for CI testing.
if envs.VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL:
return "nightly"
try:
# Get the latest commit hash of the upstream main branch.
resp_json = subprocess.check_output(
@ -508,6 +518,7 @@ class precompiled_wheel_utils:
]
).decode("utf-8")
upstream_main_commit = json.loads(resp_json)["sha"]
print(f"Upstream main branch latest commit: {upstream_main_commit}")
# In Docker build context, .git may be immutable or missing.
if envs.VLLM_DOCKER_BUILD_CONTEXT:
@ -648,7 +659,7 @@ def get_vllm_version() -> str:
if envs.VLLM_TARGET_DEVICE == "empty":
version += f"{sep}empty"
elif _is_cuda():
if envs.VLLM_USE_PRECOMPILED:
if envs.VLLM_USE_PRECOMPILED and not envs.VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX:
version += f"{sep}precompiled"
else:
cuda_version = str(get_nvcc_cuda_version())
@ -786,7 +797,7 @@ setup(
"bench": ["pandas", "matplotlib", "seaborn", "datasets"],
"tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
"runai": ["runai-model-streamer[s3,gcs] >= 0.15.0"],
"runai": ["runai-model-streamer[s3,gcs] >= 0.15.3"],
"audio": [
"librosa",
"soundfile",

View File

@ -392,39 +392,48 @@ def test_pass_config_deprecation(caplog_vllm):
assert "enable_fusion is deprecated" in caplog_vllm.text
assert config.fuse_norm_quant is True
assert config.fuse_act_quant is True
assert config.enable_fusion is None
assert config.enable_fusion is True
# Test enable_attn_fusion -> fuse_attn_quant
caplog_vllm.clear()
config = PassConfig(enable_attn_fusion=True)
assert "enable_attn_fusion is deprecated" in caplog_vllm.text
assert config.fuse_attn_quant is True
assert config.enable_attn_fusion is None
assert config.enable_attn_fusion is True
# Test enable_noop -> eliminate_noops
caplog_vllm.clear()
config = PassConfig(enable_noop=True)
assert "enable_noop is deprecated" in caplog_vllm.text
assert config.eliminate_noops is True
assert config.enable_noop is None
assert config.enable_noop is True
# Test enable_sequence_parallelism -> enable_sp
caplog_vllm.clear()
config = PassConfig(enable_sequence_parallelism=True)
assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text
assert config.enable_sp is True
assert config.enable_sequence_parallelism is None
assert config.enable_sequence_parallelism is True
# Test enable_async_tp -> fuse_gemm_comms
caplog_vllm.clear()
config = PassConfig(enable_async_tp=True)
assert "enable_async_tp is deprecated" in caplog_vllm.text
assert config.fuse_gemm_comms is True
assert config.enable_async_tp is None
assert config.enable_async_tp is True
# Test enable_fi_allreduce_fusion -> fuse_allreduce_rms
caplog_vllm.clear()
config = PassConfig(enable_fi_allreduce_fusion=True)
assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text
assert config.fuse_allreduce_rms is True
assert config.enable_fi_allreduce_fusion is None
assert config.enable_fi_allreduce_fusion is True
# Test hash consistency
config_old = PassConfig(enable_fusion=True)
config_new = PassConfig(fuse_norm_quant=True, fuse_act_quant=True)
assert config_old.compute_hash() == config_new.compute_hash()
config_old = PassConfig(enable_async_tp=True)
config_new = PassConfig(fuse_gemm_comms=True)
assert config_old.compute_hash() == config_new.compute_hash()

View File

@ -6,6 +6,7 @@ import lm_eval
import pytest
from tests.utils import large_gpu_mark
from vllm.platforms import current_platform
def get_model_args(
@ -45,6 +46,12 @@ def get_model_args(
return model_args
pytestmark = pytest.mark.skipif(
current_platform.is_rocm(),
reason="EPLB with Spec Decode is a work in progress on ROCm.",
)
@pytest.mark.parametrize(
"model_setup",
[

View File

@ -69,9 +69,20 @@ async def test_anthropic_streaming(client: anthropic.AsyncAnthropic):
stream=True,
)
first_chunk = None
chunk_count = 0
async for chunk in resp:
chunk_count += 1
if first_chunk is None and chunk.type == "message_start":
first_chunk = chunk
print(chunk.model_dump_json())
assert chunk_count > 0
assert first_chunk is not None, "message_start chunk was never observed"
assert first_chunk.usage is not None, "first chunk should include usage stats"
assert first_chunk.usage["output_tokens"] == 0
assert first_chunk.usage["input_tokens"] > 5
@pytest.mark.asyncio
async def test_anthropic_tool_call(client: anthropic.AsyncAnthropic):

View File

@ -2,64 +2,47 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import io
import numpy as np
import pytest
import requests
import torch
from vllm.utils.serial_utils import tensor2base64
from ...utils import RemoteOpenAIServer
MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
DTYPE = "float16"
def _terratorch_dummy_inputs(model_name: str):
def _terratorch_dummy_messages():
pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
buffer_tiff = io.BytesIO()
torch.save(pixel_values, buffer_tiff)
buffer_tiff.seek(0)
binary_data = buffer_tiff.read()
base64_tensor_embedding = base64.b64encode(binary_data).decode("utf-8")
buffer_coord = io.BytesIO()
torch.save(location_coords, buffer_coord)
buffer_coord.seek(0)
binary_data = buffer_coord.read()
base64_coord_embedding = base64.b64encode(binary_data).decode("utf-8")
return {
"model": model_name,
"additional_data": {"prompt_token_ids": [1]},
"encoding_format": "base64",
"messages": [
{
"role": "user",
"content": [
{
"type": "image_embeds",
"image_embeds": {
"pixel_values": base64_tensor_embedding,
"location_coords": base64_coord_embedding,
},
}
],
}
],
}
return [
{
"role": "user",
"content": [
{
"type": "image_embeds",
"image_embeds": {
"pixel_values": tensor2base64(pixel_values),
"location_coords": tensor2base64(location_coords),
},
}
],
}
]
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_request(model_name: str):
@pytest.mark.parametrize(
"model_name", ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
)
def test_single_request(model_name: str):
args = [
"--runner",
"pooling",
# use half precision for speed and memory savings in CI environment
"--dtype",
DTYPE,
"float16",
"--enforce-eager",
"--trust-remote-code",
"--max-num-seqs",
@ -70,11 +53,15 @@ async def test_single_request(model_name: str):
"--enable-mm-embeds",
]
with RemoteOpenAIServer(MODEL_NAME, args) as server:
prompt = _terratorch_dummy_inputs(model_name)
# test single pooling
response = requests.post(server.url_for("pooling"), json=prompt)
with RemoteOpenAIServer(model_name, args) as server:
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"messages": _terratorch_dummy_messages(),
"encoding_format": "base64",
},
)
response.raise_for_status()
output = response.json()["data"][0]["data"]

View File

@ -61,11 +61,8 @@ def test_pooling_params(llm: LLM):
@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
# chunked prefill does not support all pooling
err_msg = "pooling_task must be one of.+"
with pytest.raises(ValueError, match=err_msg):
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
def test_token_classify(llm: LLM):
llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
def test_score_api(llm: LLM):

View File

@ -255,21 +255,21 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
# token_classify uses ALL pooling, which does not support chunked prefill.
task = "token_classify"
input_text = ["This product was excellent and exceeded my expectations"]
response = requests.post(
server.url_for("pooling"),
json={
"model": model_name,
"input": "test",
"input": input_text,
"encoding_format": "float",
"task": task,
},
)
assert response.json()["error"]["type"] == "BadRequestError"
assert response.json()["error"]["message"].startswith(
f"Task {task} is not supported"
)
poolings = PoolingResponse.model_validate(response.json())
assert len(poolings.data) == 1
assert len(poolings.data[0].data) == 8
assert len(poolings.data[0].data[0]) == 2
@pytest.mark.asyncio

View File

@ -42,7 +42,7 @@ def llm():
@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
def test_token_embed(llm: LLM):
outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
multi_vector = outputs[0].outputs.data
assert multi_vector.shape == (11, 384)

View File

@ -36,6 +36,13 @@ def llm():
cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup
def test_config(llm: LLM):
vllm_config = llm.llm_engine.vllm_config
assert vllm_config.cache_config.enable_prefix_caching
assert vllm_config.scheduler_config.enable_chunked_prefill
def test_pooling_params(llm: LLM):
def get_outputs(use_activation):
outputs = llm.reward(

View File

@ -29,6 +29,7 @@ from vllm.multimodal.utils import (
encode_video_base64,
)
from vllm.tokenizers import MistralTokenizer, get_tokenizer
from vllm.utils.serial_utils import tensor2base64
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH
@ -85,11 +86,6 @@ def phi3v_model_config_image_embeds():
)
@pytest.fixture(scope="module")
def phi3v_tokenizer():
return get_tokenizer(PHI3V_MODEL_ID)
@pytest.fixture(scope="function")
def qwen2_audio_model_config():
return ModelConfig(
@ -115,11 +111,6 @@ def audio_embeds_model_config():
)
@pytest.fixture(scope="module")
def qwen2_audio_tokenizer():
return get_tokenizer(QWEN2AUDIO_MODEL_ID)
@pytest.fixture(scope="function")
def qwen25omni_model_config_mm_interleaved():
return ModelConfig(
@ -134,11 +125,6 @@ def qwen25omni_model_config_mm_interleaved():
)
@pytest.fixture(scope="module")
def qwen25omni_tokenizer():
return get_tokenizer(QWEN25OMNI_MODEL_ID)
@pytest.fixture(scope="function")
def mistral_model_config():
return ModelConfig(
@ -150,11 +136,6 @@ def mistral_model_config():
)
@pytest.fixture(scope="module")
def mistral_tokenizer():
return get_tokenizer(MISTRAL_MODEL_ID)
@pytest.fixture(scope="module")
def image_url():
image = ImageAsset("cherry_blossom")
@ -239,7 +220,6 @@ def _assert_mm_data_inputs(
def test_parse_chat_messages_single_image(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
@ -253,7 +233,6 @@ def test_parse_chat_messages_single_image(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -266,7 +245,6 @@ def test_parse_chat_messages_single_image(
def test_parse_chat_messages_single_image_with_uuid(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
@ -287,7 +265,6 @@ def test_parse_chat_messages_single_image_with_uuid(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -300,7 +277,6 @@ def test_parse_chat_messages_single_image_with_uuid(
def test_parse_chat_messages_single_empty_image_with_uuid(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
@ -319,7 +295,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -332,7 +307,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid(
def test_parse_chat_messages_single_image_with_bad_uuid_format(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
@ -354,7 +328,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -367,7 +340,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format(
def test_parse_chat_messages_multiple_images_with_uuids(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid1 = "my_uuid_1"
@ -397,7 +369,6 @@ def test_parse_chat_messages_multiple_images_with_uuids(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -413,7 +384,6 @@ def test_parse_chat_messages_multiple_images_with_uuids(
def test_parse_chat_messages_multiple_empty_images_with_uuids(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid1 = "my_uuid_1"
@ -439,7 +409,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -455,7 +424,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids(
def test_parse_chat_messages_mixed_empty_images_with_uuids(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid1 = "my_uuid_1"
@ -483,7 +451,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -500,7 +467,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids(
@pytest.mark.asyncio
async def test_parse_chat_messages_single_image_with_uuid_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
@ -519,7 +485,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -533,7 +498,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async(
@pytest.mark.asyncio
async def test_parse_chat_messages_empty_image_with_uuid_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
@ -552,7 +516,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -566,7 +529,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async(
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_with_uuids_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid1 = "my_uuid_1"
@ -592,7 +554,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -609,7 +570,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async(
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid1 = "my_uuid_1"
@ -635,7 +595,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -652,7 +611,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async(
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid2 = "my_uuid_2"
@ -676,7 +634,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -692,7 +649,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async(
def test_parse_chat_messages_empty_system(
mistral_model_config,
mistral_tokenizer,
):
# Test string format
conversation, _, _ = parse_chat_messages(
@ -704,7 +660,6 @@ def test_parse_chat_messages_empty_system(
},
],
mistral_model_config,
mistral_tokenizer,
content_format="string",
)
assert conversation == [
@ -722,7 +677,6 @@ def test_parse_chat_messages_empty_system(
},
],
mistral_model_config,
mistral_tokenizer,
content_format="openai",
)
assert conversation == [
@ -734,7 +688,6 @@ def test_parse_chat_messages_empty_system(
@pytest.mark.asyncio
async def test_parse_chat_messages_single_image_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
@ -748,7 +701,6 @@ async def test_parse_chat_messages_single_image_async(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -761,7 +713,6 @@ async def test_parse_chat_messages_single_image_async(
def test_parse_chat_messages_multiple_images(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
@ -779,7 +730,6 @@ def test_parse_chat_messages_multiple_images(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -795,7 +745,6 @@ def test_parse_chat_messages_multiple_images(
def test_parse_chat_messages_empty_pil_image_with_uuid(
phi3v_model_config,
phi3v_tokenizer,
):
uuid = "abcd"
conversation, mm_data, mm_uuids = parse_chat_messages(
@ -809,7 +758,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -825,7 +773,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid(
def test_parse_chat_messages_empty_image_embeds_with_uuid(
phi3v_model_config_image_embeds,
phi3v_tokenizer,
):
uuid = "abcd"
conversation, mm_data, mm_uuids = parse_chat_messages(
@ -839,7 +786,6 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
}
],
phi3v_model_config_image_embeds,
phi3v_tokenizer,
content_format="string",
)
@ -857,7 +803,6 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
def test_parse_chat_messages_empty_audio_embeds_with_uuid(
audio_embeds_model_config,
qwen2_audio_tokenizer,
):
"""Test audio_embeds with UUID (no actual embeds data)."""
uuid = "test-audio-uuid-123"
@ -873,7 +818,6 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
}
],
audio_embeds_model_config,
qwen2_audio_tokenizer,
content_format="string",
)
@ -889,11 +833,8 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
def test_parse_chat_messages_audio_embeds_with_string(
audio_embeds_model_config,
qwen2_audio_tokenizer,
):
"""Test audio_embeds with base64 string embedding data."""
import base64
import io
import torch
@ -901,11 +842,7 @@ def test_parse_chat_messages_audio_embeds_with_string(
audio_embedding = torch.randn(1, 128, 768)
# Encode it as base64
buffer = io.BytesIO()
torch.save(audio_embedding, buffer)
buffer.seek(0)
binary_data = buffer.read()
base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
base64_audio_embedding = tensor2base64(audio_embedding)
conversation, mm_data, mm_uuids = parse_chat_messages(
[
@ -921,7 +858,6 @@ def test_parse_chat_messages_audio_embeds_with_string(
}
],
audio_embeds_model_config,
qwen2_audio_tokenizer,
content_format="string",
)
@ -939,11 +875,8 @@ def test_parse_chat_messages_audio_embeds_with_string(
@pytest.mark.asyncio
async def test_parse_chat_messages_audio_embeds_async(
audio_embeds_model_config,
qwen2_audio_tokenizer,
):
"""Test audio_embeds with async futures."""
import base64
import io
import torch
@ -951,11 +884,7 @@ async def test_parse_chat_messages_audio_embeds_async(
audio_embedding = torch.randn(1, 128, 768)
# Encode it as base64
buffer = io.BytesIO()
torch.save(audio_embedding, buffer)
buffer.seek(0)
binary_data = buffer.read()
base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
base64_audio_embedding = tensor2base64(audio_embedding)
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
[
@ -971,7 +900,6 @@ async def test_parse_chat_messages_audio_embeds_async(
}
],
audio_embeds_model_config,
qwen2_audio_tokenizer,
content_format="string",
)
@ -990,7 +918,6 @@ async def test_parse_chat_messages_audio_embeds_async(
@pytest.mark.asyncio
async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
phi3v_model_config_image_embeds,
phi3v_tokenizer,
):
uuid = "abcd"
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
@ -1004,7 +931,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
}
],
phi3v_model_config_image_embeds,
phi3v_tokenizer,
content_format="string",
)
@ -1024,7 +950,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_async(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
@ -1042,7 +967,6 @@ async def test_parse_chat_messages_multiple_images_async(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -1058,7 +982,6 @@ async def test_parse_chat_messages_multiple_images_async(
def test_parse_chat_messages_placeholder_already_in_prompt(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
@ -1076,7 +999,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
assert conversation == [
@ -1091,7 +1013,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
def test_parse_chat_messages_placeholder_one_already_in_prompt(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
@ -1110,7 +1031,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -1127,7 +1047,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
def test_parse_chat_messages_multiple_images_across_messages(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
@ -1149,7 +1068,6 @@ def test_parse_chat_messages_multiple_images_across_messages(
},
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -1164,7 +1082,6 @@ def test_parse_chat_messages_multiple_images_across_messages(
def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
@ -1195,7 +1112,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
},
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -1210,7 +1126,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages(
def test_parse_chat_messages_context_text_format(
phi3v_model_config,
phi3v_tokenizer,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
[
@ -1222,7 +1137,6 @@ def test_parse_chat_messages_context_text_format(
{"role": "user", "content": "What about this one?"},
],
phi3v_model_config,
phi3v_tokenizer,
content_format="openai",
)
@ -1246,7 +1160,6 @@ def test_parse_chat_messages_context_text_format(
def test_parse_chat_messages_rejects_too_many_images_in_one_message(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
with warnings.catch_warnings():
@ -1277,14 +1190,12 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
def test_parse_chat_messages_rejects_too_many_images_across_messages(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
with warnings.catch_warnings():
@ -1322,14 +1233,12 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
},
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
def test_parse_chat_messages_multiple_images_uncommon_input(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
@ -1344,7 +1253,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
}
],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
@ -1360,7 +1268,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
def test_parse_chat_messages_multiple_images_interleave(
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
@ -1380,7 +1287,6 @@ def test_parse_chat_messages_multiple_images_interleave(
}
],
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string",
)
@ -1398,7 +1304,6 @@ def test_parse_chat_messages_multiple_images_interleave(
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_interleave_async(
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages_futures(
@ -1418,7 +1323,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
}
],
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string",
)
@ -1436,7 +1340,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async(
@pytest.mark.asyncio
async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
@ -1465,7 +1368,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
}
],
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string",
)
@ -1482,7 +1384,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async(
def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url,
):
conversation, mm_data, mm_uuids = parse_chat_messages(
@ -1505,7 +1406,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
},
],
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string",
)
@ -1523,7 +1423,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave(
def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave(
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url,
):
image_uuid = str(hash(image_url))
@ -1555,7 +1454,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl
},
],
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string",
)
@ -1573,7 +1471,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl
def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
image_url,
video_url,
audio_url,
@ -1601,7 +1498,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
},
],
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
content_format="string",
)
@ -1627,7 +1523,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave(
def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave(
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
image_url,
video_url,
audio_url,
@ -1671,7 +1566,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl
},
],
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
content_format="string",
)
@ -1699,7 +1593,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl
def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_messages_interleave( # noqa: E501
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
image_url,
video_url,
audio_url,
@ -1743,7 +1636,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes
},
],
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
content_format="string",
)
@ -1775,7 +1667,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes
def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave( # noqa: E501
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
image_url,
video_url,
audio_url,
@ -1811,7 +1702,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message
},
],
qwen25omni_model_config_mm_interleaved,
qwen25omni_tokenizer,
content_format="string",
)
@ -1837,7 +1727,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message
def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
image_url,
):
with pytest.raises(
@ -1861,7 +1750,6 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders(
}
],
phi3v_model_config_mm_interleaved,
phi3v_tokenizer,
content_format="string",
)
@ -2237,9 +2125,7 @@ def test_resolve_content_format_examples(template_path, expected_format):
assert resolved_format == expected_format
def test_parse_chat_messages_include_thinking_chunk(
mistral_model_config, mistral_tokenizer
):
def test_parse_chat_messages_include_thinking_chunk(mistral_model_config):
messages = [
{
"role": "system",
@ -2269,7 +2155,6 @@ def test_parse_chat_messages_include_thinking_chunk(
conversation_with_thinking, _, _ = parse_chat_messages(
messages,
mistral_model_config,
mistral_tokenizer,
content_format="openai",
)
@ -2353,7 +2238,6 @@ def test_apply_mistral_chat_template_thinking_chunk():
def test_parse_chat_messages_single_empty_audio_with_uuid(
qwen2_audio_model_config,
qwen2_audio_tokenizer,
):
audio_uuid = "abcd"
conversation, mm_data, mm_uuids = parse_chat_messages(
@ -2371,7 +2255,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid(
}
],
qwen2_audio_model_config,
qwen2_audio_tokenizer,
content_format="string",
)
@ -2389,7 +2272,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid(
@pytest.mark.asyncio
async def test_parse_chat_messages_single_empty_audio_with_uuid_async(
qwen2_audio_model_config,
qwen2_audio_tokenizer,
):
audio_uuid = "abcd"
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
@ -2407,7 +2289,6 @@ async def test_parse_chat_messages_single_empty_audio_with_uuid_async(
}
],
qwen2_audio_model_config,
qwen2_audio_tokenizer,
content_format="string",
)

View File

@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (
BatchedTritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@ -286,16 +283,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
needs_matching_quant=False,
needs_deep_gemm=True,
)
register_experts(
BatchedTritonOrDeepGemmExperts,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=True,
needs_deep_gemm=True,
)
register_experts(
TritonOrDeepGemmExperts,
standard_format,
@ -457,10 +444,6 @@ def make_fused_experts(
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedTritonExperts {kwargs} ...")
experts = BatchedTritonExperts(**kwargs)
elif fused_experts_type == BatchedTritonOrDeepGemmExperts:
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts:
print(f"Making DeepGemmExperts {quant_config} ...")
experts = DeepGemmExperts(quant_config)

View File

@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_per_token_group_quant_fp8_colmajor,
silu_mul_per_token_group_quant_fp8_colmajor,
)
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
FLOAT8_DTYPE = torch.float8_e4m3fn
GROUP_SIZE = 128
def reference_quant(x: torch.Tensor, use_ue8m0: bool):
    """Quantize ``x`` group-wise to FP8 via the reference triton kernel.

    Launches ``_per_token_group_quant_fp8_colmajor`` from
    ``vllm.model_executor.layers.quantization.utils.fp8_utils`` and returns
    the quantized tensor together with its column-major scale tensor.
    """
    quantized = torch.empty_like(x, device=x.device, dtype=FLOAT8_DTYPE)

    # Scales are column-major: allocate transposed, then permute the view.
    scale_shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
    scales = torch.empty(scale_shape, device=x.device, dtype=torch.float32).permute(
        -1, -2
    )

    num_groups = x.numel() // GROUP_SIZE
    block = triton.next_power_of_2(GROUP_SIZE)
    # Warp-count heuristic: roughly one warp per 256 lanes, clamped to [1, 8].
    warps = min(max(block // 256, 1), 8)

    fp8_info = torch.finfo(FLOAT8_DTYPE)
    _per_token_group_quant_fp8_colmajor[(num_groups,)](
        x,
        quantized,
        scales,
        GROUP_SIZE,
        x.shape[1],
        x.stride(0),
        scales.stride(1),
        eps=1e-10,
        fp8_min=fp8_info.min,
        fp8_max=fp8_info.max,
        use_ue8m0=use_ue8m0,
        BLOCK=block,
        num_warps=warps,
        num_stages=1,
    )
    return quantized, scales
def reference(x: torch.Tensor, use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference path: fused ``silu_and_mul`` activation followed by FP8 quant.

    Args:
        x: Input of shape ``(T, N)``; the gated activation halves ``N``.
        use_ue8m0: Passed through to :func:`reference_quant` for the scale
            format.

    Returns:
        ``(quantized, scales)`` as produced by :func:`reference_quant`.
    """
    T, N = x.size()
    # Allocate on x's device instead of hard-coding "cuda" so the helper also
    # works under an explicit device selection such as "cuda:1".
    ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device=x.device)
    torch.ops._C.silu_and_mul(ref_act_out, x)
    return reference_quant(ref_act_out, use_ue8m0)
@pytest.mark.parametrize("T", [128, 256, 512])
@pytest.mark.parametrize("N", [128 * 2, 256 * 2, 768 * 2, 2048 * 2, 7168 * 2])
def test_silu_mul_fp8_quant_deep_gemm(T: int, N: int):
    """Fused silu+mul+quant kernel must match the two-step reference path."""
    current_platform.seed_everything(42)

    # Renamed from `input` to avoid shadowing the builtin.
    x = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")

    use_ue8m0 = is_deep_gemm_e8m0_used()

    # Test: single fused kernel.
    output, output_scales = silu_mul_per_token_group_quant_fp8_colmajor(
        x, use_ue8m0=use_ue8m0
    )

    # Reference: separate activation then quantization.
    ref_output, ref_output_scales = reference(x, use_ue8m0)

    # Compare in float32 so FP8 payloads can be asserted element-wise.
    torch.testing.assert_close(output.to(torch.float32), ref_output.to(torch.float32))
    torch.testing.assert_close(output_scales, ref_output_scales)

View File

@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.executor import UniProcExecutor
from vllm.v1.worker.worker_base import WorkerWrapperBase
# This is a dummy executor for patching in test_runai_model_streamer_s3.py.
# We cannot use vllm_runner fixture here, because it spawns worker process.
# The worker process reimports the patched entities, and the patch is not applied.
class RunaiDummyExecutor(UniProcExecutor):
    """In-process single-worker executor for the runai streamer S3 tests.

    The regular vllm_runner fixture spawns a worker process which re-imports
    the (unpatched) modules, defeating monkeypatching; this executor keeps
    the driver worker in the test process.
    """

    def _init_executor(self) -> None:
        """Create and initialize a single driver worker in-process."""
        distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())

        # Derive the local rank from a "<type>:<index>" device string;
        # default to 0 when no index is present.  Use str() rather than
        # calling .__str__() directly.
        device_info = str(self.vllm_config.device_config.device).split(":")
        local_rank = int(device_info[1]) if len(device_info) > 1 else 0

        worker_rpc_kwargs = dict(
            vllm_config=self.vllm_config,
            local_rank=local_rank,
            rank=0,
            distributed_init_method=distributed_init_method,
            is_driver_worker=True,
        )
        wrapper_kwargs = {
            "vllm_config": self.vllm_config,
        }
        self.driver_worker = WorkerWrapperBase(**wrapper_kwargs)
        self.collective_rpc("init_worker", args=([worker_rpc_kwargs],))
        self.collective_rpc("init_device")

View File

@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
from huggingface_hub import snapshot_download
from runai_model_streamer.safetensors_streamer.streamer_mock import StreamerPatcher
from vllm.engine.arg_utils import EngineArgs
from .conftest import RunaiDummyExecutor
load_format = "runai_streamer"
test_model = "openai-community/gpt2"
def test_runai_model_loader_download_files_s3_mocked_with_patch(
    vllm_runner,
    tmp_path: Path,
    monkeypatch,
):
    """Load a mocked S3 model through the runai streamer with all I/O shimmed."""
    patcher = StreamerPatcher(str(tmp_path))
    mock_s3_uri = "s3://my-mock-bucket/gpt2/"

    # Fetch the real model from HF into the temp dir backing the mock bucket.
    snapshot_download(repo_id=test_model, local_dir=f"{tmp_path}/gpt2")

    # Route every runai/streamer entry point through the patcher's shims.
    shims = {
        "vllm.transformers_utils.runai_utils.runai_list_safetensors": (
            patcher.shim_list_safetensors
        ),
        "vllm.transformers_utils.runai_utils.runai_pull_files": (
            patcher.shim_pull_files
        ),
        "vllm.model_executor.model_loader.weight_utils.SafetensorsStreamer": (
            patcher.create_mock_streamer
        ),
    }
    for target, replacement in shims.items():
        monkeypatch.setattr(target, replacement)

    vllm_config = EngineArgs(
        model=mock_s3_uri,
        load_format=load_format,
        tensor_parallel_size=1,
    ).create_engine_config()

    executor = RunaiDummyExecutor(vllm_config)
    executor.driver_worker.load_model()

View File

@ -0,0 +1,53 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from transformers import AutoModel
from tests.models.utils import check_embeddings_close
from vllm import TokensPrompt
@pytest.mark.parametrize(
    "model",
    ["Qwen/Qwen3-Embedding-0.6B"],
)
@torch.inference_mode
def test_embed_models(hf_runner, vllm_runner, model: str):
    """Token embeddings from vLLM with chunked prefill must match HF."""
    chunk_size = 10
    prompt_lengths = [55, 56, 57]
    token_prompts = [[1024 + i for i in range(n)] for n in prompt_lengths]

    with vllm_runner(
        model,
        runner="pooling",
        max_model_len=128,
        max_num_batched_tokens=chunk_size,
        enforce_eager=True,
        # `enable_chunked_prefill`: Set to `False` instead of `None` in VllmRunner
        enable_chunked_prefill=True,
        enable_prefix_caching=True,
    ) as vllm_model:
        vllm_outputs = vllm_model.token_embed(
            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
        )

    with hf_runner(
        model,
        auto_cls=AutoModel,
    ) as hf_model:
        hf_outputs = []
        for prompt in token_prompts:
            batch = hf_model.wrap_device({"input_ids": torch.tensor([prompt])})
            hidden = hf_model.model(batch["input_ids"]).last_hidden_state
            hf_outputs.append(hidden.cpu().float()[0])

    for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
        check_embeddings_close(
            embeddings_0_lst=hf_output,
            embeddings_1_lst=vllm_output,
            name_0="hf",
            name_1="vllm",
            tol=1e-2,
        )

View File

@ -20,7 +20,6 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
max_model_len=128,
enforce_eager=True,
runner="pooling",
enable_chunked_prefill=False,
enable_prefix_caching=True,
) as vllm_model:
pooling_outputs = vllm_model.llm.encode(

View File

@ -2,6 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pytest configuration for vLLM tests."""
import warnings
import torch
from vllm.platforms import current_platform
@ -14,6 +16,20 @@ def pytest_configure(config):
if not current_platform.is_rocm():
return
skip_patterns = ["test_granite_speech.py"]
if any(pattern in str(arg) for arg in config.args for pattern in skip_patterns):
# Skip disabling SDP for Granite Speech tests on ROCm
return
# Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
# accuracy issues
# TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
warnings.warn(
"ROCm: Disabled flash_sdp and mem_efficient_sdp, enabled math_sdp "
"to avoid HuggingFace Transformers accuracy issues",
UserWarning,
stacklevel=1,
)

View File

@ -403,12 +403,13 @@ VLM_TEST_SETTINGS = {
# So, we need to reduce the number of tokens for the test to pass.
max_tokens=8,
num_logprobs=10,
auto_cls=AutoModelForCausalLM,
marks=[large_gpu_mark(min_gb=32)],
),
"glm4_1v": VLMTestInfo(
models=["zai-org/GLM-4.1V-9B-Thinking"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>",
prompt_formatter=lambda img_prompt: f"[gMASK]<|user|>\n{img_prompt}<|assistant|>\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>",
video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>",
max_model_len=2048,
@ -423,6 +424,7 @@ VLM_TEST_SETTINGS = {
models=["zai-org/GLM-4.1V-9B-Thinking"],
# GLM4.1V require include video metadata for input
test_type=VLMTestType.CUSTOM_INPUTS,
prompt_formatter=lambda vid_prompt: f"[gMASK]<|user|>\n{vid_prompt}<|assistant|>\n", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
@ -737,7 +739,13 @@ VLM_TEST_SETTINGS = {
max_model_len=8192,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
marks=[large_gpu_mark(min_gb=48)],
marks=[
large_gpu_mark(min_gb=48),
pytest.mark.skipif(
current_platform.is_rocm(),
reason="Model produces a vector of <UNK> output in HF on ROCm",
),
],
),
"qwen_vl": VLMTestInfo(
models=["Qwen/Qwen-VL"],

View File

@ -8,6 +8,7 @@ from transformers import AutoModelForSpeechSeq2Seq
from vllm.logprobs import SampleLogprobs
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from ....conftest import AudioTestAssets, HfRunner, PromptAudioInput, VllmRunner
from ...registry import HF_EXAMPLE_MODELS
@ -34,6 +35,12 @@ audio_lora_path = MODEL_NAME
models = [MODEL_NAME]
@pytest.fixture(autouse=True)
def set_attention_backend_for_rocm(monkeypatch):
if current_platform.is_rocm():
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
def run_test(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
@ -111,8 +118,12 @@ def run_test(
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_model_len", [2048])
@pytest.mark.parametrize(
"dtype", ["float16"] if current_platform.is_rocm() else ["bfloat16"]
)
@pytest.mark.parametrize(
"max_model_len", [512] if current_platform.is_rocm() else [2048]
)
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(

View File

@ -1,281 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Sequence
import librosa
import pytest
from huggingface_hub import snapshot_download
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
from vllm.multimodal.image import rescale_image_size
from ....conftest import (
IMAGE_ASSETS,
HfRunner,
PromptAudioInput,
PromptImageInput,
VllmRunner,
)
from ....utils import large_gpu_test
from ...utils import check_logprobs_close
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
{
"stop_sign": "<|user|>\n<|image|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501
"cherry_blossom": "<|user|>\n<|image|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501
}
)
HF_MULTIIMAGE_IMAGE_PROMPT = (
"<|user|>\n<|image|>\n<|image|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501
)
model_path = snapshot_download(
"microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70"
)
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
vision_lora_path = os.path.join(model_path, "vision-lora")
speech_question = os.path.join(
model_path, "examples", "what_is_shown_in_this_image.wav"
)
models = [model_path]
target_dtype = "half"
def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    inputs: Sequence[tuple[list[str], PromptImageInput, PromptAudioInput | None]],
    model: str,
    *,
    max_model_len: int,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    mm_limit: int,
    tensor_parallel_size: int,
    distributed_executor_backend: str | None = None,
):
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are from IMAGE_ASSETS.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.

    Args:
        hf_runner: Factory for the HuggingFace comparison runner.
        vllm_runner: Factory for the vLLM runner under test.
        inputs: Per-case ``(prompts, images, audios)`` tuples; ``audios``
            may be ``None`` for image-only cases.
        model: Model path (the local Phi-4-multimodal snapshot).
        max_model_len: Context length; must exceed the image feature size.
        dtype: Weight/activation dtype string for both runners.
        max_tokens: Greedy generation length per prompt.
        num_logprobs: Number of logprobs compared per generated token.
        mm_limit: Per-prompt image limit passed to vLLM.
        tensor_parallel_size: TP degree for vLLM.
        distributed_executor_backend: Optional vLLM executor backend.
    """
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    # max_model_len should be greater than image_feature_size
    with vllm_runner(
        model,
        task="generate",
        max_model_len=max_model_len,
        max_num_seqs=2,
        dtype=dtype,
        limit_mm_per_prompt={"image": mm_limit},
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
        enable_lora=True,
        max_lora_rank=320,
        gpu_memory_utilization=0.8,  # set to 0.8 to avoid OOM in CI
        enforce_eager=True,
        trust_remote_code=False,
    ) as vllm_model:
        # The vision LoRA ships alongside the base model and must be
        # requested explicitly for every generation.
        lora_request = LoRARequest("vision", 1, vision_lora_path)
        vllm_outputs_per_case = [
            vllm_model.generate_greedy_logprobs(
                prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                images=images,
                audios=audios,
                lora_request=lora_request,
            )
            for prompts, images, audios in inputs
        ]
    with hf_runner(model, dtype=dtype) as hf_model:
        # Mirror the LoRA setup on the HF side so outputs are comparable.
        hf_model.model.load_adapter(
            vision_lora_path,
            adapter_name="vision",
        )
        hf_processor = hf_model.processor
        eos_token_id = hf_processor.tokenizer.eos_token_id
        hf_outputs_per_case = [
            hf_model.generate_greedy_logprobs_limit(
                prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                images=images,
                audios=audios,
                eos_token_id=eos_token_id,
            )
            for prompts, images, audios in inputs
        ]
    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_outputs,
            name_0="hf",
            name_1="vllm",
        )
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [12800])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(
    hf_runner,
    vllm_runner,
    image_assets,
    model,
    size_factors,
    dtype: str,
    max_model_len: int,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Single-image comparison of HF vs vLLM across image scale factors."""
    images = [asset.pil_image for asset in image_assets]

    # One case per asset: the same prompt repeated once per scale factor,
    # paired with the image rescaled by each factor.
    inputs_per_image = [
        (
            [prompt for _ in size_factors],
            [rescale_image_size(image, factor) for factor in size_factors],
            None,
        )
        for image, prompt in zip(images, HF_IMAGE_PROMPTS)
    ]
    run_test(
        hf_runner,
        vllm_runner,
        inputs_per_image,
        model,
        dtype=dtype,
        max_model_len=max_model_len,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        # [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [25600])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_multi_images_models(
    hf_runner,
    vllm_runner,
    image_assets,
    model,
    size_factors,
    dtype: str,
    max_model_len: int,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Multi-image (two images per prompt) comparison of HF vs vLLM."""
    images = [asset.pil_image for asset in image_assets]

    # A single case: the multi-image prompt repeated once per scale factor,
    # each paired with ALL assets rescaled by that factor.
    inputs_per_case = [
        (
            [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
            [
                [rescale_image_size(image, factor) for image in images]
                for factor in size_factors
            ],
            None,
        ),
    ]
    run_test(
        hf_runner,
        vllm_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_model_len=max_model_len,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=2,
        tensor_parallel_size=1,
    )
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [12800])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_vision_speech_models(
    hf_runner,
    vllm_runner,
    model,
    dtype: str,
    max_model_len: int,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Combined image+audio prompt comparison of HF vs vLLM."""
    # use the example speech question so that the model outputs are reasonable
    audio = librosa.load(speech_question, sr=16000)
    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")

    # One case: a single prompt carrying both an image and an audio clip.
    inputs_vision_speech = [
        (
            ["<|user|><|image|><|audio|><|end|><|assistant|>"],
            [image],
            [audio],
        ),
    ]
    run_test(
        hf_runner,
        vllm_runner,
        inputs_vision_speech,
        model,
        dtype=dtype,
        max_model_len=max_model_len,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )

View File

@ -15,6 +15,7 @@ from transformers import AutoProcessor
from vllm import SamplingParams, TextPrompt, TokensPrompt
from vllm.logprobs import Logprob, SampleLogprobs
from vllm.multimodal import MultiModalDataBuiltins
from vllm.platforms import current_platform
from ....utils import VLLM_PATH, large_gpu_test
from ...utils import check_logprobs_close
@ -165,6 +166,15 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs:
def test_chat(
vllm_runner, max_model_len: int, model: str, dtype: str, local_asset_server
) -> None:
if (
model == MISTRAL_SMALL_3_1_ID
and max_model_len == 65536
and current_platform.is_rocm()
):
pytest.skip(
"OOM on ROCm: 24B model with 65536 context length exceeds GPU memory"
)
EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT[model])
with vllm_runner(
model,

View File

@ -140,7 +140,7 @@ def video_with_metadata_glm4_1v():
metadata = VIDEO_ASSETS[0].metadata
question = "Describe the video."
video_prompt = "<|begin_of_video|><|video|><|end_of_video|>"
formatted_prompt = f"<|user|>\n{video_prompt}{question}<|assistant|>\n"
formatted_prompt = f"[gMASK]<|user|>\n{video_prompt}{question}<|assistant|>\n"
scales = [0.1, 0.2, 0.25]
video_input = [

View File

@ -25,6 +25,7 @@ from transformers import (
from transformers.video_utils import VideoMetadata
from vllm.logprobs import SampleLogprobs
from vllm.platforms import current_platform
from vllm.utils.collection_utils import is_list_of
from .....conftest import HfRunner, ImageAsset, ImageTestAssets
@ -366,6 +367,40 @@ def gemma3_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOut
def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner to use for GLM4V."""
if current_platform.is_rocm():
import types
config = hf_model.model.config
if hasattr(config, "num_layers") and not hasattr(config, "num_hidden_layers"):
config.num_hidden_layers = config.num_layers
config.output_hidden_states = True
def patched_prepare_cache(
self, generation_config, model_kwargs, *args, **kwargs
):
model_kwargs["past_key_values"] = None
model_kwargs["use_cache"] = False
return model_kwargs
hf_model.model._prepare_cache_for_generation = types.MethodType(
patched_prepare_cache, hf_model.model
)
original_generate = hf_model.model.generate
def patched_generate(*args, **kwargs):
kwargs["output_hidden_states"] = True
kwargs["return_dict_in_generate"] = True
return original_generate(*args, **kwargs)
hf_model.model.generate = patched_generate
original_forward = hf_model.model.forward
def patched_forward(*args, **kwargs):
kwargs["output_hidden_states"] = True
return original_forward(*args, **kwargs)
hf_model.model.forward = patched_forward
hf_processor = hf_model.processor
def processor(*args, text="", images=None, **kwargs):
@ -406,7 +441,15 @@ def glm4_1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
if videos is not None and is_list_of(videos, tuple):
# If videos is a list of tuples, we assume each tuple contains
# (video_array, metadata) as in the case of GLM4.1V.
video_metadata = [[VideoMetadata(**video[1])] for video in videos]
# Filter out 'do_sample_frames' as it's not a valid VideoMetadata arg
video_metadata = [
[
VideoMetadata(
**{k: v for k, v in video[1].items() if k != "do_sample_frames"}
)
]
for video in videos
]
videos = [[video[0]] for video in videos]
else:
video_metadata = None

View File

@ -0,0 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pytest configuration for vLLM pooling tests."""
import os
import warnings
from vllm.platforms import current_platform
def pytest_collection_modifyitems(config, items):
    """On ROCm, force the FLEX_ATTENTION backend whenever SigLIP tests run."""
    if not current_platform.is_rocm():
        return

    # Only flip the backend if at least one collected item is a SigLIP test.
    has_siglip = any("test_siglip" in item.nodeid for item in items)
    if has_siglip:
        os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
        warnings.warn(
            "ROCm: Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION for SigLIP tests",
            UserWarning,
            stacklevel=1,
        )

View File

@ -396,28 +396,6 @@ def test_processing_correctness(
)
# Phi4MultimodalForCausalLM share same model repo with original format
# Phi4MMForCausalLM, so we add it as a separate test case
# Remove this test after conversion PR merged:
# https://huggingface.co/microsoft/Phi-4-multimodal-instruct/discussions/70
@pytest.mark.parametrize("model_arch", ["Phi4MultimodalForCausalLM"])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
def test_processing_correctness_phi4_multimodal(
    model_arch: str,
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    """Delegate to the shared processing-correctness harness for Phi-4-MM."""
    _test_processing_correctness(
        model_arch,
        hit_rate=hit_rate,
        num_batches=num_batches,
        simplify_rate=simplify_rate,
    )
def _assert_inputs_equal(
a: MultiModalInputs,
b: MultiModalInputs,

View File

@ -667,6 +667,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"moonshotai/Kimi-VL-A3B-Instruct",
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"},
trust_remote_code=True,
max_transformers_version="4.53.3",
transformers_version_reason="HF model uses deprecated transformers API "
"(PytorchGELUTanh, DynamicCache.seen_tokens, and more). See: "
"https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/discussions/31",
),
"LightOnOCRForConditionalGeneration": _HfExamplesInfo(
"lightonai/LightOnOCR-1B",
@ -767,10 +771,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Phi4MMForCausalLM": _HfExamplesInfo(
"microsoft/Phi-4-multimodal-instruct", trust_remote_code=True
),
"Phi4MultimodalForCausalLM": _HfExamplesInfo(
"microsoft/Phi-4-multimodal-instruct",
revision="refs/pr/70",
),
"PixtralForConditionalGeneration": _HfExamplesInfo(
"mistralai/Pixtral-12B-2409",
extras={

View File

@ -112,7 +112,7 @@ class TestBaseThinkingReasoningParserMethods:
"""Test the is_reasoning_end method."""
parser = TestThinkingReasoningParser(test_tokenizer)
end_token_id = parser.end_token_id
start_token_id = parser.start_token_id
# Test with end token present
assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True
@ -122,6 +122,16 @@ class TestBaseThinkingReasoningParserMethods:
# Test with empty list
assert parser.is_reasoning_end([]) is False
# Test with interleaved thinking
assert parser.is_reasoning_end([1, start_token_id, 2, end_token_id]) is True
assert parser.is_reasoning_end([1, start_token_id, 2, 3]) is False
assert (
parser.is_reasoning_end(
[1, start_token_id, 2, end_token_id, 2, 2, start_token_id]
)
is False
)
def test_extract_content_ids(self, test_tokenizer):
"""Test the extract_content_ids method."""
parser = TestThinkingReasoningParser(test_tokenizer)

View File

@ -5,6 +5,10 @@
set -e
set -x
merge_base_commit=$(git merge-base HEAD origin/main)
echo "Current merge base commit with main: $merge_base_commit"
git show --oneline -s $merge_base_commit
cd /vllm-workspace/
# uninstall vllm
@ -18,7 +22,7 @@ apt autoremove -y
echo 'import os; os.system("touch /tmp/changed.file")' >> vllm/__init__.py
VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL=1 VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e .
VLLM_PRECOMPILED_WHEEL_COMMIT=$merge_base_commit VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e .
# Run the script
python3 -c 'import vllm'

View File

@ -629,8 +629,8 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files):
(
"internlm/internlm2-1_8b-reward",
"decoder",
False,
"Pooling models with all pooling does not support chunked prefill.",
True,
"Pooling models with causal attn and all pooling support chunked prefill.",
),
(
"BAAI/bge-base-en",
@ -748,8 +748,8 @@ def test_is_chunked_prefill_supported(
(
"internlm/internlm2-1_8b-reward",
"decoder",
False,
"Pooling models with all pooling does not support prefix caching.",
True,
"Pooling models with causal attn and all pooling support prefix caching.",
),
(
"BAAI/bge-base-en",

View File

@ -365,3 +365,54 @@ class TestEnvSetWithChoices:
with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}):
env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"])
assert env_func() == {"option1", "option2"}
class TestVllmConfigureLogging:
    """Test cases for VLLM_CONFIGURE_LOGGING environment variable."""

    @staticmethod
    def _reset_envs_cache():
        """Drop envs' memoized lookups so the next read re-checks os.environ."""
        getter = envs.__getattr__
        if hasattr(getter, "cache_clear"):
            getter.cache_clear()

    def test_configure_logging_defaults_to_true(self):
        """Test that VLLM_CONFIGURE_LOGGING defaults to True when not set."""
        with patch.dict(os.environ, {}, clear=False):
            # Ensure the env var is not set.
            os.environ.pop("VLLM_CONFIGURE_LOGGING", None)
            self._reset_envs_cache()
            value = envs.VLLM_CONFIGURE_LOGGING
            assert value is True
            assert isinstance(value, bool)

    def test_configure_logging_with_zero_string(self):
        """Test that VLLM_CONFIGURE_LOGGING='0' evaluates to False."""
        with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "0"}):
            self._reset_envs_cache()
            value = envs.VLLM_CONFIGURE_LOGGING
            assert value is False
            assert isinstance(value, bool)

    def test_configure_logging_with_one_string(self):
        """Test that VLLM_CONFIGURE_LOGGING='1' evaluates to True."""
        with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "1"}):
            self._reset_envs_cache()
            value = envs.VLLM_CONFIGURE_LOGGING
            assert value is True
            assert isinstance(value, bool)

    def test_configure_logging_with_invalid_value_raises_error(self):
        """Test that invalid VLLM_CONFIGURE_LOGGING value raises ValueError."""
        with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "invalid"}):
            self._reset_envs_cache()
            with pytest.raises(ValueError, match="invalid literal for int"):
                _ = envs.VLLM_CONFIGURE_LOGGING

View File

@ -0,0 +1,847 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Generator
import partial_json_parser
import pytest
from mistral_common.protocol.instruct.messages import AssistantMessage
from mistral_common.protocol.instruct.request import InstructRequest
from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolParser
from vllm.tokenizers import (
MistralTokenizer,
TokenizerLike,
get_tokenizer,
)
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
@pytest.fixture(scope="module")
def mistral_pre_v11_tokenizer():
    """HF-style tokenizer for an older Mistral model (can encode free text)."""
    MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
    return get_tokenizer(tokenizer_name=MODEL)
@pytest.fixture(scope="module")
def mistral_tokenizer():
    """Native Mistral tokenizer for a newer instruct model (v11+ family)."""
    MODEL = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
    return get_tokenizer(tokenizer_name=MODEL, tokenizer_mode="mistral")
@pytest.fixture
def mistral_pre_v11_tool_parser(mistral_pre_v11_tokenizer):
    """Tool parser bound to the pre-v11 tokenizer."""
    return MistralToolParser(mistral_pre_v11_tokenizer)
@pytest.fixture
def mistral_tool_parser(mistral_tokenizer):
    """Tool parser bound to the newer (native Mistral) tokenizer."""
    return MistralToolParser(mistral_tokenizer)
def assert_tool_calls(
    actual_tool_calls: list[ToolCall] | list[DeltaToolCall],
    expected_tool_calls: list[ToolCall],
):
    """Compare parsed tool calls against the expected ones pairwise.

    Checks the auto-generated id (a 9-character string — presumably the
    Mistral id format; confirm against the parser), the per-type
    invariants of each call object, and finally the function name and
    argument string.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual_tool_call, expected_tool_call in zip(
        actual_tool_calls, expected_tool_calls
    ):
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) == 9

        # Complete calls carry a type; streamed deltas carry a (partial)
        # function payload that must at least have name and arguments.
        if isinstance(actual_tool_call, ToolCall):
            assert actual_tool_call.type == "function"
        elif isinstance(actual_tool_call, DeltaToolCall):
            assert actual_tool_call.function is not None
            assert actual_tool_call.function.name is not None
            assert actual_tool_call.function.arguments is not None

        assert actual_tool_call.function is not None
        assert actual_tool_call.function.name == expected_tool_call.function.name, (
            f"got wrong function name:${actual_tool_call.function.name}"
        )
        assert (
            actual_tool_call.function.arguments == expected_tool_call.function.arguments
        ), f"got wrong function argument:${actual_tool_call.function.arguments}"
def fix_tool_call_tokenization(
    tokens: list[int],
    mistral_tool_parser: MistralToolParser,
    mistral_tokenizer: TokenizerLike,
):
    """Collapse each textual [TOOL_CALLS] token run into its special token id.

    Tokenizing free text spells the bot token out as several ordinary
    tokens; every occurrence of that run is rewritten to the single
    special id the parser expects. Inputs too short to contain the run
    are returned untouched.
    """
    spelled_out = mistral_tokenizer.encode(
        text=mistral_tool_parser.bot_token,
        add_special_tokens=False,
    )
    # `spelled_out` must not contain special tokens like bos, eos etc.
    replacement = [mistral_tool_parser.bot_token_id]

    window = len(spelled_out)
    # If the input is too short to contain the sequence, nothing to replace.
    if not tokens or len(tokens) < window:
        return tokens

    fixed: list[int] = []
    pos = 0
    while pos < len(tokens):
        if tokens[pos : pos + window] == spelled_out:
            # Match: emit the special id and skip past the whole run.
            fixed.extend(replacement)
            pos += window
        else:
            fixed.append(tokens[pos])
            pos += 1
    return fixed
def stream_delta_message_generator(
    mistral_tool_parser: MistralToolParser,
    mistral_tokenizer: TokenizerLike,
    model_output: str | None,
    tools: list[tuple[str, str]] | None,
) -> Generator[DeltaMessage, None, None]:
    """Feed the parser one token at a time and yield each streamed delta.

    For v11+ Mistral tokenizers free text cannot be tokenized directly,
    so the token stream is produced by encoding an assistant message
    built from ``tools`` (name/arguments pairs). For older tokenizers,
    ``model_output`` is encoded directly and the spelled-out
    ``[TOOL_CALLS]`` text is rewritten to its special token id.
    """
    if (
        isinstance(mistral_tokenizer, MistralTokenizer)
        and mistral_tokenizer.version >= 11
    ):
        # With the newer versions of the tokenizer,
        # we cannot tokenize free text
        # so we need to create a list of messages to get tokenized
        assert tools is not None
        assistant_msg = AssistantMessage(
            tool_calls=[
                ToolCall(
                    function=FunctionCall(
                        name=name,
                        arguments=arg,
                    )
                )
                for (name, arg) in tools
            ],
        )
        request = InstructRequest(
            messages=[assistant_msg],
        )
        all_token_ids = mistral_tokenizer.instruct.encode_instruct(request).tokens
    else:
        # Older versions of the tokenizer are
        # able to encode directly the model's output (free text) into tokens
        assert model_output is not None
        all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False)
        all_token_ids = fix_tool_call_tokenization(
            all_token_ids, mistral_tool_parser, mistral_tokenizer
        )

    # Incremental-detokenization state carried across iterations.
    previous_text = ""
    previous_tokens = None
    prefix_offset = 0
    read_offset = 0
    for i, delta_token in enumerate(all_token_ids):
        delta_token_ids = [delta_token]
        previous_token_ids = all_token_ids[:i]
        current_token_ids = all_token_ids[: i + 1]

        (new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
            detokenize_incrementally(
                tokenizer=mistral_tokenizer,
                all_input_ids=current_token_ids,
                prev_tokens=previous_tokens,
                prefix_offset=prefix_offset,
                read_offset=read_offset,
                # NOTE(review): specials are skipped only for MistralTokenizer —
                # presumably they cannot be rendered as text there; confirm.
                skip_special_tokens=isinstance(mistral_tokenizer, MistralTokenizer),
                spaces_between_special_tokens=True,
            )
        )

        current_text = previous_text + delta_text

        delta_message = mistral_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=None,  # type: ignore[arg-type]
        )
        if delta_message:
            yield delta_message

        # Advance the incremental-detokenization state for the next token.
        previous_text = current_text
        previous_tokens = (
            previous_tokens + new_tokens if previous_tokens else new_tokens
        )
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset
def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser):
    """Plain text with no bot token must pass through untouched as content."""
    plain_text = "This is a test"
    result = mistral_pre_v11_tool_parser.extract_tool_calls(
        plain_text, request=None
    )  # type: ignore[arg-type]
    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == plain_text
@pytest.mark.parametrize(
ids=[
"single_tool_add",
"single_tool_weather",
"argument_before_name",
"argument_before_name_and_name_in_argument",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
)
)
],
None,
),
(
"""[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
None,
),
(
"""[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
None,
),
(
"""[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_age",
arguments=json.dumps(
{
"name": "John Doe",
}
),
)
)
],
None,
),
],
)
def test_extract_tool_calls_pre_v11_tokenizer(
    mistral_pre_v11_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction with the pre-v11 (free-text) tokenizer."""
    result = mistral_pre_v11_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert result.tools_called
    assert_tool_calls(result.tool_calls, expected_tool_calls)
    assert result.content == expected_content
@pytest.mark.parametrize(
ids=[
"single_tool_add",
"single_tool_weather",
"multiple_tool_calls",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add_this_and_that",
arguments=json.dumps({"a": 3.5, "b": 4}),
)
)
],
None,
),
(
"""[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
None,
),
(
"""[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
)
),
ToolCall(
function=FunctionCall(
name="multiply", arguments=json.dumps({"a": 3, "b": 6})
)
),
],
None,
),
],
)
def test_extract_tool_calls(
    mistral_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction with the newer (native Mistral) tokenizer."""
    result = mistral_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert result.tools_called
    assert_tool_calls(result.tool_calls, expected_tool_calls)
    assert result.content == expected_content
def _test_extract_tool_calls_streaming(
    tool_parser, tokenizer, model_output, tools, expected_tool_calls, expected_content
):
    """Stream the output token by token and verify the reassembled tool calls.

    Accumulates content deltas, function names, and argument fragments from
    every streamed ``DeltaMessage``, then rebuilds complete ``ToolCall``s
    and compares them (plus any plain content) against the expectations.
    """
    other_content: str = ""
    function_names: list[str] = []
    function_args_strs: list[str] = []
    # Index of the tool call currently being streamed (-1 = none yet).
    tool_call_idx: int = -1
    tool_call_ids: list[str | None] = []

    for delta_message in stream_delta_message_generator(
        tool_parser, tokenizer, model_output, tools
    ):
        # role should never be streamed from tool parser
        assert not delta_message.role

        if delta_message.content:
            other_content += delta_message.content

        streamed_tool_calls = delta_message.tool_calls

        if streamed_tool_calls and len(streamed_tool_calls) > 0:
            # make sure only one diff is present - correct even for parallel
            assert len(streamed_tool_calls) == 1
            tool_call = streamed_tool_calls[0]

            assert len(tool_parser.prev_tool_call_arr) > 0

            # if a new tool is being called, set up empty arguments
            if tool_call.index != tool_call_idx:
                tool_call_idx = tool_call.index
                function_args_strs.append("")
                tool_call_ids.append(None)

            # if a tool call ID is streamed, make sure one hasn't been already
            if tool_call.id and not tool_call_ids[tool_call.index]:
                tool_call_ids[tool_call.index] = tool_call.id

            # if parts of the function start being streamed
            if tool_call.function:
                # if the function name is defined, set it. it should be streamed
                # IN ENTIRETY, exactly one time.
                if tool_call.function.name:
                    assert isinstance(tool_call.function.name, str)
                    function_names.append(tool_call.function.name)

                if tool_call.function.arguments:
                    # make sure they're a string and then add them to the list
                    assert isinstance(tool_call.function.arguments, str)
                    function_args_strs[tool_call.index] += tool_call.function.arguments

    assert other_content == expected_content

    # Rebuild full ToolCalls from the streamed fragments; the partial JSON
    # parser tolerates argument strings cut mid-object/mid-string.
    actual_tool_calls = [
        ToolCall(
            id=tool_call_id,
            function=FunctionCall(
                name=function_name,
                arguments=partial_json_parser.ensure_json(
                    function_args_str, Allow.OBJ | Allow.STR
                ),
            ),
        )
        for tool_call_id, function_name, function_args_str in zip(
            tool_call_ids, function_names, function_args_strs
        )
    ]
    assert_tool_calls(actual_tool_calls, expected_tool_calls)
@pytest.mark.parametrize(
ids=[
"no_tools",
"single_tool_add",
"single_tool_add_strings",
"single_tool_weather",
"argument_before_name",
"argument_before_name_and_name_in_argument",
"multiple_tools",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
("""This is a test""", [], """This is a test"""),
(
"""[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3, "b": 4})
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": "3", "b": "4"})
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_age",
arguments=json.dumps(
{
"name": "John Doe",
}
),
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
)
),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
),
],
"",
),
],
)
def test_extract_tool_calls_streaming_pre_v11_tokenizer(
    mistral_pre_v11_tool_parser,
    mistral_pre_v11_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Token-by-token streaming extraction with the pre-v11 tokenizer."""
    _test_extract_tool_calls_streaming(
        tool_parser=mistral_pre_v11_tool_parser,
        tokenizer=mistral_pre_v11_tokenizer,
        model_output=model_output,
        tools=None,
        expected_tool_calls=expected_tool_calls,
        expected_content=expected_content,
    )
@pytest.mark.parametrize(
ids=[
"single_tool_add",
"single_tool_add_strings",
"multiple_tools",
],
argnames=["tools", "expected_tool_calls", "expected_content"],
argvalues=[
(
[("add", '{"a": 3, "b": 4}')],
# [TOOL_CALLS]add{"a": 3, "b": 4}
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3, "b": 4})
)
)
],
"",
),
(
[("add_two_strings", '{"a": "3", "b": "4"}')],
# [TOOL_CALLS]add_two_strings{"a": "3", "b": "4"}
[
ToolCall(
function=FunctionCall(
name="add_two_strings",
arguments=json.dumps({"a": "3", "b": "4"}),
)
)
],
"",
),
(
[
("add", '{"a": 3.5, "b": 4}'),
(
"get_current_weather",
'{"city": "San Francisco", "state": "CA", "unit": "celsius"}', # noqa: E501
),
],
# [TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"} # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
)
),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
),
],
"",
),
],
)
def test_extract_tool_calls_streaming(
    mistral_tool_parser,
    mistral_tokenizer,
    tools,
    expected_tool_calls,
    expected_content,
):
    """Token-by-token streaming extraction with the newer tokenizer."""
    _test_extract_tool_calls_streaming(
        tool_parser=mistral_tool_parser,
        tokenizer=mistral_tokenizer,
        model_output=None,
        tools=tools,
        expected_tool_calls=expected_tool_calls,
        expected_content=expected_content,
    )
@pytest.mark.parametrize(
ids=[
"single_tool_add",
"single_tool_weather",
"multiple_tool_calls",
"content_before_tool",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add_this_and_that",
arguments=json.dumps({"a": 3.5, "b": 4}),
)
)
],
"",
),
(
"""[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
"",
),
(
"""[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
)
),
ToolCall(
function=FunctionCall(
name="multiply", arguments=json.dumps({"a": 3, "b": 6})
)
),
],
"",
),
(
# Additional content should not be after the tool calls
"""bla[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add_this_and_that",
arguments=json.dumps({"a": 3.5, "b": 4}),
)
)
],
"bla",
),
],
)
def test_extract_tool_calls_streaming_one_chunk(
    mistral_tool_parser,
    mistral_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Deliver the whole model output as a single streaming delta."""
    if isinstance(mistral_tokenizer, MistralTokenizer):
        token_ids = mistral_tokenizer.encode(model_output)
    else:
        token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False)
    token_ids = fix_tool_call_tokenization(
        token_ids, mistral_tool_parser, mistral_tokenizer
    )

    delta = mistral_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=model_output,
        delta_text=model_output,
        previous_token_ids=[],
        current_token_ids=token_ids,
        delta_token_ids=token_ids,
        request=None,
    )  # type: ignore[arg-type]

    assert isinstance(delta, DeltaMessage)
    assert len(delta.tool_calls) == len(expected_tool_calls)
    assert_tool_calls(delta.tool_calls, expected_tool_calls)
    # A missing content delta is equivalent to empty expected content.
    if delta.content is None:
        assert expected_content == ""
    else:
        assert delta.content == expected_content
@pytest.mark.parametrize(
ids=[
"no_tools",
"single_tool_add",
"single_tool_add_strings",
"single_tool_weather",
"argument_before_name",
"argument_before_name_and_name_in_argument",
"multiple_tools",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
("""This is a test""", [], """This is a test"""),
(
"""[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3, "b": 4})
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": "3", "b": "4"})
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="get_age",
arguments=json.dumps(
{
"name": "John Doe",
}
),
)
)
],
"",
),
(
"""[TOOL_CALLS] [{"arguments": {"a": 3.5, "b": 4}, "name": "add"}, {"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501
[
ToolCall(
function=FunctionCall(
name="add", arguments=json.dumps({"a": 3.5, "b": 4})
)
),
ToolCall(
function=FunctionCall(
name="get_current_weather",
arguments=json.dumps(
{"city": "San Francisco", "state": "CA", "unit": "celsius"}
),
)
),
],
"",
),
],
)
def test_extract_tool_calls_streaming_pre_v11_tokenizer_one_chunk(
    mistral_pre_v11_tool_parser,
    mistral_pre_v11_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Deliver the whole pre-v11 model output as one streaming delta."""
    if isinstance(mistral_pre_v11_tokenizer, MistralTokenizer):
        token_ids = mistral_pre_v11_tokenizer.encode(model_output)
    else:
        token_ids = mistral_pre_v11_tokenizer.encode(
            model_output, add_special_tokens=False
        )
    token_ids = fix_tool_call_tokenization(
        token_ids, mistral_pre_v11_tool_parser, mistral_pre_v11_tokenizer
    )

    delta = mistral_pre_v11_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=model_output,
        delta_text=model_output,
        previous_token_ids=[],
        current_token_ids=token_ids,
        delta_token_ids=token_ids,
        request=None,
    )  # type: ignore[arg-type]

    assert isinstance(delta, DeltaMessage)
    assert len(delta.tool_calls) == len(expected_tool_calls)
    assert_tool_calls(delta.tool_calls, expected_tool_calls)
    # A missing content delta is equivalent to empty expected content.
    if delta.content is None:
        assert expected_content == ""
    else:
        assert delta.content == expected_content

View File

@ -123,7 +123,7 @@ CONFIGS: dict[str, ServerConfig] = {
"supports_parallel": True,
"extended": True,
},
"mistral": {
"mistral-7b": {
"model": "mistralai/Mistral-7B-Instruct-v0.3",
"arguments": [
"--enforce-eager",
@ -145,6 +145,32 @@ CONFIGS: dict[str, ServerConfig] = {
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally.",
"supports_parallel": True,
},
"mistral-small-3.2": {
"model": "mistralai/Mistral-Small-3.2-24B-Instruct-2506",
"arguments": [
"--enforce-eager",
"--no-enable-prefix-caching",
"--tool-call-parser",
"mistral",
"--tokenizer-mode",
"mistral",
"--config-format",
"mistral",
"--load-format",
"mistral",
"--tensor-parallel-size",
"4",
'--ignore-patterns="consolidated.safetensors"',
],
"system_prompt": "You are a helpful assistant with access to tools. If a tool"
" that you have would be helpful to answer a user query, "
"call the tool. Otherwise, answer the user's query directly "
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
"to the user's question - just respond to it normally.",
"supports_parallel": True,
"extended": True,
},
# FIXME: This test currently fails, need to debug why.
# "granite20b": {

View File

@ -11,7 +11,9 @@ PROMPTS = [
]
def test_reset_prefix_cache_e2e():
def test_reset_prefix_cache_e2e(monkeypatch):
# "spawn" is required for test to be deterministic
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
engine_args = EngineArgs(
model="Qwen/Qwen3-0.6B",
gpu_memory_utilization=0.2,

View File

@ -9,6 +9,7 @@ from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.hashing import _xxhash
def test_prefix_caching_from_cli():
@ -48,6 +49,21 @@ def test_prefix_caching_from_cli():
args = parser.parse_args(["--prefix-caching-hash-algo", "invalid"])
@pytest.mark.skipif(_xxhash is None, reason="xxhash not installed")
def test_prefix_caching_xxhash_from_cli():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
# set hash algorithm to xxhash (pickle)
args = parser.parse_args(["--prefix-caching-hash-algo", "xxhash"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash"
# set hash algorithm to xxhash_cbor
args = parser.parse_args(["--prefix-caching-hash-algo", "xxhash_cbor"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash_cbor"
def test_defaults_with_usage_context():
engine_args = EngineArgs(model="facebook/opt-125m")
vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS)

View File

@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import io
import json
import openai # use the official client for correctness check
@ -13,6 +11,7 @@ from transformers import AutoConfig
from tests.conftest import ImageTestAssets
from tests.utils import RemoteOpenAIServer
from vllm.utils.serial_utils import tensor2base64
# any model with a chat template should work here
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
@ -50,18 +49,6 @@ async def client_with_image_embeds(server_with_image_embeds):
yield async_client
def encode_image_embedding_to_base64(image_embedding) -> str:
    """Serialize *image_embedding* with ``torch.save`` and return it as a
    base64-encoded UTF-8 string."""
    buf = io.BytesIO()
    torch.save(image_embedding, buf)
    # getvalue() yields the whole serialized payload; no seek() needed.
    return base64.b64encode(buf.getvalue()).decode("utf-8")
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("dtype", [torch.half, torch.float16, torch.float32])
@ -73,7 +60,7 @@ async def test_completions_with_image_embeds(
):
# Test case: Single image embeds input
image_embeds = image_assets[0].image_embeds.to(dtype=dtype)
base64_image_embedding = encode_image_embedding_to_base64(image_embeds)
base64_image_embedding = tensor2base64(image_embeds)
chat_completion = await client_with_image_embeds.chat.completions.create(
messages=[
{"role": "system", "content": "You are a helpful assistant."},

View File

@ -3,12 +3,14 @@
from dataclasses import asdict
from typing import NamedTuple
import pytest
from PIL import Image
from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import KVTransferConfig
from vllm.multimodal.utils import encode_image_base64
from vllm.platforms import current_platform
MODEL_NAME = "RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w8a8"
@ -108,6 +110,13 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]):
print("-" * 50)
@pytest.mark.skipif(
current_platform.is_rocm(),
reason=(
"hipErrorLaunchFailure when running this test, see issue:"
"https://github.com/ROCm/pytorch/issues/2822"
),
)
def test_shared_storage_connector_hashes(tmp_path):
"""
Tests that SharedStorageConnector saves KV to the storage locations

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import sys
from typing import Any
import pytest
@ -10,7 +9,6 @@ from tests.utils import create_new_process_for_each_test
from tests.v1.logits_processors.utils import (
DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN,
DUMMY_LOGITPROC_MODULE,
MAX_TOKENS,
MODEL_NAME,
POOLING_MODEL_NAME,
@ -18,7 +16,6 @@ from tests.v1.logits_processors.utils import (
CustomLogitprocSource,
DummyLogitsProcessor,
WrappedPerReqLogitsProcessor,
dummy_module,
prompts,
)
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
@ -162,8 +159,6 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource
kwargs: dict[str, list[str | type[LogitsProcessor]]] = {}
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
# Scenario: load logitproc based on fully-qualified class name (FQCN)
# Inject dummy module which defines logitproc
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
# Scenario: load logitproc from provided class object

View File

@ -14,11 +14,9 @@ from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_te
from tests.v1.logits_processors.utils import (
DUMMY_LOGITPROC_ARG,
DUMMY_LOGITPROC_FQCN,
DUMMY_LOGITPROC_MODULE,
MAX_TOKENS,
MODEL_NAME,
TEMP_GREEDY,
dummy_module,
prompts,
)
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
@ -47,20 +45,14 @@ def _server_with_logitproc_entrypoint(
main.main()
def _server_with_logitproc_module(
def _server_with_logitproc_fqcn(
env_dict: dict[str, str] | None,
model: str,
vllm_serve_args: list[str],
) -> None:
"""Start vLLM server, inject module with dummy logitproc"""
# Patch `modules` to inject dummy logitproc module
from vllm.entrypoints.cli import main
sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module
# fork is required for workers to see entrypoint patch
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork"
if env_dict is not None:
os.environ.update(env_dict)
@ -99,7 +91,7 @@ def server(default_server_args, request, monkeypatch):
if request.param:
# Launch server, append FQCN argument, inject dummy logitproc module
args = default_server_args + request.param
_server_fxn = _server_with_logitproc_module
_server_fxn = _server_with_logitproc_fqcn
else:
# Launch server, inject dummy logitproc entrypoint
args = default_server_args

View File

@ -27,7 +27,7 @@ DUMMY_LOGITPROC_ARG = "target_token"
TEMP_GREEDY = 0.0
MAX_TOKENS = 20
DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc"
DUMMY_LOGITPROC_MODULE = "DummyModule"
DUMMY_LOGITPROC_MODULE = "tests.v1.logits_processors.utils"
DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor"

View File

@ -5,6 +5,7 @@ import torch
from vllm.config import SpeculativeConfig
from vllm.model_executor.models.interfaces import supports_eagle3
from vllm.platforms import current_platform
@pytest.mark.parametrize(
@ -21,6 +22,10 @@ from vllm.model_executor.models.interfaces import supports_eagle3
pytest.param(
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
id="qwen3-eagle3-speculator-w4a16-verifier",
marks=pytest.mark.skipif(
current_platform.is_rocm(),
reason="The tests are skipped on rocm platform.",
),
),
],
)

View File

@ -761,6 +761,10 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[1] == layer_1
@pytest.mark.skipif(
current_platform.is_rocm(),
reason="Attention backend FLASHINFER is not supported on ROCm.",
)
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
"""
The GPU model runner creates different views into the

View File

@ -283,6 +283,28 @@ def _rocm_aiter_grouped_topk_fake(
pass
# Cache whether aiter supports FP8 MLA parameters
_AITER_MLA_SUPPORTS_FP8: bool | None = None
def _check_aiter_mla_fp8_support() -> bool:
"""Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters."""
global _AITER_MLA_SUPPORTS_FP8
if _AITER_MLA_SUPPORTS_FP8 is None:
try:
import inspect
from aiter.mla import mla_decode_fwd
sig = inspect.signature(mla_decode_fwd)
_AITER_MLA_SUPPORTS_FP8 = (
"q_scale" in sig.parameters and "kv_scale" in sig.parameters
)
except Exception:
_AITER_MLA_SUPPORTS_FP8 = False
return _AITER_MLA_SUPPORTS_FP8
def _rocm_aiter_mla_decode_fwd_impl(
q: torch.Tensor,
kv_buffer: torch.Tensor,
@ -299,6 +321,16 @@ def _rocm_aiter_mla_decode_fwd_impl(
) -> None:
from aiter.mla import mla_decode_fwd
kwargs = {
"sm_scale": sm_scale,
"logit_cap": logit_cap,
}
# Only pass q_scale and kv_scale if the aiter library supports them
if _check_aiter_mla_fp8_support():
kwargs["q_scale"] = q_scale
kwargs["kv_scale"] = kv_scale
mla_decode_fwd(
q,
kv_buffer.view(-1, 1, 1, q.shape[-1]),
@ -308,10 +340,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
kv_indices,
kv_last_page_lens,
max_seqlen_qo,
sm_scale=sm_scale,
logit_cap=logit_cap,
q_scale=q_scale,
kv_scale=kv_scale,
**kwargs,
)

View File

@ -104,7 +104,8 @@ class FixFunctionalizationPass(VllmInductorPass):
mutated_args = {1: "result"}
self.defunctionalize(graph, node, mutated_args)
elif (
at_target
hasattr(torch.ops.vllm, "flashinfer_trtllm_fused_allreduce_norm")
and at_target
== torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default
):
mutated_args = {

View File

@ -30,7 +30,7 @@ CacheDType = Literal[
"fp8_ds_mla",
]
MambaDType = Literal["auto", "float32"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"]
@ -77,9 +77,21 @@ class CacheConfig:
"""Whether to enable prefix caching."""
prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256"
"""Set the hash algorithm for prefix caching:\n
- "sha256" uses Pickle for object serialization before hashing.\n
- "sha256" uses Pickle for object serialization before hashing. This is the
current default, as SHA256 is the most secure choice to avoid potential
hash collisions.\n
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It
serializes objects using canonical CBOR and hashes them with SHA-256."""
serializes objects using canonical CBOR and hashes them with SHA-256.\n
- "xxhash" uses Pickle serialization with xxHash (128-bit) for faster,
non-cryptographic hashing. Requires the optional ``xxhash`` package.
IMPORTANT: Use of a hashing algorithm that is not considered
cryptographically secure theoretically increases the risk of hash collisions,
which can cause undefined behavior or even leak private information in
multi-tenant environments. Even if collisions are still very unlikely, it is
important to consider your security risk tolerance against the performance
benefits before turning this on.\n
- "xxhash_cbor" combines canonical CBOR serialization with xxHash for
reproducible hashing. Requires the optional ``xxhash`` package."""
cpu_offload_gb: float = Field(default=0, ge=0)
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to

View File

@ -4,7 +4,7 @@
import enum
from collections import Counter
from collections.abc import Callable
from dataclasses import asdict, field
from dataclasses import field
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal
@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.utils import config, handle_deprecated
from vllm.config.utils import config, get_hash_factors, handle_deprecated, hash_factors
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
@ -196,7 +196,16 @@ class PassConfig:
Any new fields that affect compilation should be added to the hash.
Any future fields that don't affect compilation should be excluded.
"""
return InductorPass.hash_dict(asdict(self))
ignored_fields = [
"enable_fusion",
"enable_attn_fusion",
"enable_noop",
"enable_sequence_parallelism",
"enable_async_tp",
"enable_fi_allreduce_fusion",
]
return hash_factors(get_hash_factors(self, ignored_factors=ignored_fields))
@field_validator(
"fuse_norm_quant",
@ -267,14 +276,6 @@ class PassConfig:
"v0.13.0 or v1.0.0, whichever is sooner",
)
# Force old flags to None to ensure they are not used
self.enable_fusion = None
self.enable_attn_fusion = None
self.enable_noop = None
self.enable_sequence_parallelism = None
self.enable_async_tp = None
self.enable_fi_allreduce_fusion = None
if not self.eliminate_noops:
if self.fuse_norm_quant or self.fuse_act_quant:
logger.warning_once(

View File

@ -84,7 +84,7 @@ TaskOption = Literal[
"transcription",
"draft",
]
TokenizerMode = Literal["auto", "hf", "slow", "mistral"]
TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal[
"raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs"
@ -141,6 +141,7 @@ class ModelConfig:
- "hf" will use the fast tokenizer if available.\n
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
- Other custom values can be supported via plugins."""
trust_remote_code: bool = False
"""Trust remote code (e.g., from HuggingFace) when downloading the model
@ -1779,20 +1780,22 @@ class ModelConfig:
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]:
if pooling_type in ["mean", "step", "cls"]:
logger.debug(
"Pooling models with %s pooling does not "
"support chunked prefill.",
pooling_type,
)
return False
else:
# pooling_type == "last"
elif pooling_type in ["all", "last"]:
logger.debug(
"Pooling models with causal attn and last pooling support "
"chunked prefill."
"Pooling models with causal attn and %s pooling support "
"chunked prefill.",
pooling_type,
)
return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return attn_type != "encoder_decoder"
@ -1816,20 +1819,22 @@ class ModelConfig:
return False
elif attn_type == "decoder":
pooling_type = self.pooler_config.pooling_type.lower()
if pooling_type in ["all", "mean", "step", "cls"]:
if pooling_type in ["mean", "step", "cls"]:
logger.debug(
"Pooling models with %s pooling does not "
"support prefix caching.",
pooling_type,
)
return False
else:
# pooling_type == "last"
elif pooling_type in ["all", "last"]:
logger.debug(
"Pooling models with causal attn and last pooling support "
"prefix caching."
"Pooling models with causal attn and %s pooling support "
"prefix caching.",
pooling_type,
)
return True
else:
raise ValueError(f"{pooling_type=} not supported.")
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return False

View File

@ -593,10 +593,14 @@ class ParallelConfig:
"max_parallel_loading_workers is currently "
"not supported and will be ignored."
)
if self.distributed_executor_backend not in ("mp", "uni") and self.nnodes > 1:
allowed_backends = ("mp", "uni", "external_launcher")
if (
self.distributed_executor_backend not in allowed_backends
and self.nnodes > 1
):
raise ValueError(
"nnodes > 1 can only be set when distributed executor "
"backend is mp or uni."
"backend is mp, uni or external_launcher."
)
@property

View File

@ -190,3 +190,8 @@ KVConnectorFactory.register_connector(
"vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector",
"DecodeBenchConnector",
)
KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector",
"MooncakeConnector",
)

View File

@ -4,13 +4,14 @@
KV cache helper for store.
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
@ -21,89 +22,6 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
class model_aware_kv_ops_helper:
    """Model-topology-aware helpers for reading and writing KV cache tensors.

    Encapsulates the layout differences between standard attention KV caches
    and Deepseek MLA caches, whose shape depends on VLLM_MLA_DISABLE.
    """

    def __init__(self, config: VllmConfig):
        # True when the model is a Deepseek MLA model.
        self.is_deepseek_mla = config.model_config.is_deepseek_mla
        # MLA forward-absorb layout is used unless VLLM_MLA_DISABLE is set.
        self.use_mla_opt = not envs.VLLM_MLA_DISABLE
        self.tp_size = config.parallel_config.tensor_parallel_size

    def get_model_args(self, model_executable: torch.nn.Module):
        """Return ``(num_heads, head_size)`` describing this worker's KV cache slice.

        Also stashes *model_executable* on ``self`` for later use.
        """
        model_config = model_executable.model.config
        self.model_executable = model_executable
        # KV heads are sharded evenly across TP ranks.
        num_heads = int(model_config.num_key_value_heads / self.tp_size)
        hidden_size = model_config.hidden_size
        num_attention_heads = model_config.num_attention_heads

        # Deepseek's MLA (Multi-head Latent Attention) uses two different
        # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0.
        # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied,
        # resulting in a kv_cache shape of [num_blks, blk_size, 1,
        # kv_lora_rank + qk_rope_head_dim].
        # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading
        # to a kv_cache shape of [2, num_blks, blk_size,
        # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim].
        # For more details, see vllm/v1/attention/backends/mla/common.py.
        if self.is_deepseek_mla and self.use_mla_opt:
            head_size = model_config.kv_lora_rank + model_config.qk_rope_head_dim
            num_heads = 1
        elif self.is_deepseek_mla and not self.use_mla_opt:
            head_size = model_config.qk_nope_head_dim + model_config.qk_rope_head_dim
        else:
            head_size = getattr(model_config, "head_dim", None)
            if head_size is None:
                # Fall back to deriving the head dim from the hidden size.
                head_size = int(hidden_size // num_attention_heads)
        return num_heads, head_size

    def get_kv_from_cache(self, kv_cache, num_heads, head_size):
        """Return ``(key_cache, value_cache)`` reshaped to [tokens, heads, head_size]."""
        if self.is_deepseek_mla and self.use_mla_opt:
            # MLA keeps a single latent buffer; K and V views alias the same data.
            key_cache = kv_cache.reshape(-1, num_heads, head_size)
            value_cache = kv_cache.reshape(-1, num_heads, head_size)
        else:
            # Standard layout: index 0 is keys, index 1 is values.
            key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
            value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
        return key_cache, value_cache

    def put_kv_to_cache(
        self,
        model_executable: torch.nn.Module,
        keys,
        values,
        layer,
        kv_cache,
        slot_mapping,
        start_pos,
        end_pos,
    ):
        """Write *keys*/*values* for slot_mapping[start_pos:end_pos] into *kv_cache*.

        Dispatches to the MLA cache op or the standard flash cache op depending
        on the model/layout detected in ``__init__``.
        """
        model_config = model_executable.model.config

        if self.is_deepseek_mla and self.use_mla_opt:
            # NOTE(review): rebinds the attention module to the MLA variant so
            # the scale/dtype attributes below come from it.
            layer.self_attn.attn = layer.self_attn.mla_attn
            # keys holds the concatenated latent vector: [kv_lora_rank | rope dims].
            k_c_normed_k_pe = keys.squeeze(1)
            k_c_normed = k_c_normed_k_pe[:, : model_config.kv_lora_rank]
            k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank :]
            ops.concat_and_cache_mla(
                k_c_normed.to(kv_cache.device),
                k_pe.to(kv_cache.device),
                kv_cache,
                slot_mapping[start_pos:end_pos],
                layer.self_attn.attn.kv_cache_dtype,
                layer.self_attn.attn._k_scale,
            )
        else:
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            ops.reshape_and_cache_flash(
                keys.to(key_cache.device),
                values.to(value_cache.device),
                key_cache,
                value_cache,
                slot_mapping[start_pos:end_pos],
                layer.self_attn.attn.kv_cache_dtype,
                layer.self_attn.attn._k_scale,
                layer.self_attn.attn._v_scale,
            )
def get_kv_connector_cache_layout():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
# used for faster transfer.
@ -266,3 +184,124 @@ def copy_kv_blocks(
src_tensor = src_kv_caches[layer_name]
dst_tensor = dst_kv_caches[layer_name]
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
@dataclass
class TpKVTopology:
    """
    Helper class for tensor parallel and KV topology information for
    mapping between local and remote TP workers.
    """

    # Local TP rank of this worker.
    tp_rank: int
    # Map of engine_id -> TP world size (includes the local engine's entry).
    remote_tp_size: dict[str, int]
    # Whether the model uses MLA attention (KV then cannot be split by head).
    is_mla: bool
    total_num_kv_heads: int
    attn_backend: type[AttentionBackend]
    # This (local) engine's id; used to look up local size in the maps above.
    engine_id: str
    # Map of engine_id -> KV block size (includes the local engine's entry).
    remote_block_size: dict[str, int]

    def __post_init__(self):
        # Figure out whether the first dimension of the cache is K/V
        # or num_blocks. This is used to register the memory regions correctly.
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
        )
        # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
        # we just mock num_blocks to 1 for the dimension check below.
        self._is_kv_layout_blocks_first = (
            len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
        )
        attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
        self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS

    @property
    def is_kv_layout_blocks_first(self) -> bool:
        # True when the cache's leading dim is num_blocks rather than K/V.
        return self._is_kv_layout_blocks_first

    @property
    def split_k_and_v(self) -> bool:
        # Whether to register regions for K and V separately (when present).
        return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first)

    @property
    def tp_size(self) -> int:
        # Local TP world size (this engine's entry in the remote map).
        return self.remote_tp_size[self.engine_id]

    @property
    def block_size(self) -> int:
        # Local KV block size (this engine's entry in the remote map).
        return self.remote_block_size[self.engine_id]

    def tp_ratio(
        self,
        remote_tp_size: int,
    ) -> int:
        """
        Calculate the tensor parallel ratio between local and remote TP.
        We can think of it as the number of local TP workers-per-remote TP
        workers. Local workers will read from the same remote TP worker in
        groups of size `tp_ratio`.
        """
        assert self.tp_size % remote_tp_size == 0, (
            f"Local tensor parallel size {self.tp_size} is not divisible "
            f"by remote tensor parallel size {remote_tp_size}."
        )
        return self.tp_size // remote_tp_size

    def block_size_ratio(
        self,
        remote_block_size: int,
    ) -> int:
        """
        Calculate the block size ratio between local and remote TP.
        """
        # NOTE(review): the message says "or vice versa" but only local %
        # remote == 0 is checked — confirm whether the reverse is intended.
        assert self.block_size % remote_block_size == 0, (
            f"Local block size {self.block_size} is not divisible "
            f"by remote block size {remote_block_size} or vice versa."
        )
        return self.block_size // remote_block_size

    def tp_ratio_from_engine_id(
        self,
        remote_engine_id: str,
    ) -> int:
        """tp_ratio looked up by remote engine id."""
        remote_tp_size = self.remote_tp_size[remote_engine_id]
        return self.tp_ratio(remote_tp_size)

    def block_size_ratio_from_engine_id(
        self,
        remote_engine_id: str,
    ) -> int:
        """block_size_ratio looked up by remote engine id."""
        remote_block_size = self.remote_block_size[remote_engine_id]
        return self.block_size_ratio(remote_block_size)

    def is_kv_replicated(self, engine_id: str) -> bool:
        """
        Whether the KV cache is replicated across TP workers due to the
        number of TP workers being greater than the number of KV heads.
        """
        tp_size = self.remote_tp_size[engine_id]
        return tp_size // self.total_num_kv_heads >= 1

    def replicates_kv_cache(self, remote_engine_id: str) -> bool:
        # MLA is always replicated as the hidden dim can't be split.
        return self.is_mla or self.is_kv_replicated(remote_engine_id)

    def get_target_remote_rank(
        self,
        remote_tp_size: int,
    ) -> int:
        """
        Get the remote TP rank (on P) that the current local TP rank
        (on D) will read from.
        """
        tp_ratio = self.tp_ratio(remote_tp_size)
        return self.tp_rank // tp_ratio

    def get_target_remote_rank_from_engine_id(
        self,
        remote_engine_id: str,
    ) -> int:
        """get_target_remote_rank looked up by remote engine id."""
        remote_tp_size = self.remote_tp_size[remote_engine_id]
        return self.get_target_remote_rank(remote_tp_size)

View File

@ -0,0 +1,914 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import threading
import time
import uuid
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import msgspec
import numpy as np
import torch
import zmq
import zmq.asyncio
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group,
)
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
try:
from mooncake.engine import TransferEngine
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run VLLM with MooncakeTransferEngine."
) from e
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
EngineId = str
ReqId = str
TRANS_DONE = b"trans_done"
TRANS_ERROR = b"trans_error"
logger = init_logger(__name__)
class MooncakeAgentMetadata(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    # required for @cached_property.
    dict=True,
):
    """Side-channel message describing a peer's KV cache layout for transfer."""

    # Hostname and port of the remote endpoint.
    remote_hostname: str
    remote_port: int
    # Request ids this metadata covers.
    request_ids: list[ReqId]
    # Base addresses of the registered KV cache regions on the sender.
    kv_caches_base_addr: list[int]
    # presumably one block-id list per request in request_ids — TODO confirm
    block_ids: list[list[int]]
@dataclass
class RecvReqMeta:
    """Worker-side metadata for a request whose KV is pulled from a remote peer."""

    # Local block ids to fill; empty when there is a full local prefix hit.
    local_block_ids: list[int]
    # Side-channel address of the remote worker to pull from.
    remote_host: str
    remote_port: int
@dataclass
class SendBlockMeta:
    """State for a request whose blocks will be read by a remote peer."""

    local_block_ids: list[int]
    # Signaled when the blocks are ready for the remote side to read.
    ready: threading.Event
    # presumably a wall-clock deadline after which the entry may be
    # expired and its blocks freed — TODO confirm against worker code.
    expire_time: float = float("inf")
@dataclass
class SendReqMeta:
    """Lock-protected map of in-flight send requests (shared across threads)."""

    reqs: dict[ReqId, SendBlockMeta]
    lock: threading.Lock
@dataclass
class FinishedSendReqSet:
    """Lock-protected set of request ids whose sends have completed."""

    # NOTE(review): field name shadows the builtin `set`.
    set: set[ReqId]
    lock: threading.Lock
@dataclass
class FinishedReceiveReqSet:
    """Asyncio-lock-protected set of request ids whose receives have completed."""

    # NOTE(review): field name shadows the builtin `set`.
    set: set[ReqId]
    # asyncio.Lock — this set is touched from the async receive path.
    lock: asyncio.Lock
class MooncakeConnectorMetadata(KVConnectorMetadata):
    """Per-step metadata handed from the scheduler side to the worker side."""

    def __init__(self):
        # Requests whose KV must be pulled from a remote (prefill) worker.
        self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {}
        # Requests whose local blocks await transfer to a remote (decode) worker.
        self.reqs_to_send: dict[ReqId, list[int]] = {}

    def add_new_req(
        self,
        request_id: ReqId,
        local_block_ids: list[int],
        kv_transfer_params: dict[str, Any],
        load_remote_cache: bool = True,
    ):
        """Register *request_id* for remote-KV receive, or for sending when
        ``load_remote_cache`` is False."""
        if not load_remote_cache:
            self.reqs_to_send[request_id] = local_block_ids
            return
        self.reqs_to_recv[request_id] = RecvReqMeta(
            local_block_ids=local_block_ids,
            remote_host=kv_transfer_params["remote_host"],
            remote_port=kv_transfer_params["remote_port"],
        )
class MooncakeConnector(KVConnectorBase_V1):
    """Facade connector that delegates to a scheduler-side or worker-side
    implementation depending on the role it was constructed with."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: Optional["KVCacheConfig"] = None,
    ):
        super().__init__(vllm_config, role, kv_cache_config)
        assert vllm_config.kv_transfer_config is not None
        assert vllm_config.kv_transfer_config.engine_id is not None
        self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id

        # Exactly one delegate is instantiated based on the role; the other
        # stays None, and each method asserts the delegate it needs exists.
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler: MooncakeConnectorScheduler | None = (
                MooncakeConnectorScheduler(vllm_config, self.engine_id)
            )
            self.connector_worker: MooncakeConnectorWorker | None = None
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int, bool]:
        """Delegate to the scheduler-side implementation."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens
        )

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """Delegate to the scheduler-side implementation."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens
        )

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Delegate to the scheduler-side implementation."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """Delegate to the scheduler-side implementation."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Delegate KV cache registration to the worker-side implementation."""
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """Get the finished recving and sending requests."""
        # NOTE: finished_req_ids is not forwarded to the worker delegate.
        assert self.connector_worker is not None
        return self.connector_worker.get_finished()

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        """Kick off KV loading described by the current connector metadata."""
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata, MooncakeConnectorMetadata)
        self.connector_worker.start_load_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """MooncakeConnector does not do layerwise saving."""
        pass

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> None:
        """MooncakeConnector does not save explicitly."""
        pass

    def wait_for_save(self):
        # No-op: saving is handled asynchronously by the transfer engine.
        pass
class MooncakeConnectorScheduler:
"""Implementation of Scheduler side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
self.engine_id: EngineId = engine_id
self.side_channel_host = get_ip()
self.side_channel_port = get_mooncake_side_channel_port(vllm_config)
assert vllm_config.kv_transfer_config
self.kv_role = vllm_config.kv_transfer_config.kv_role
logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id)
# Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[ReqId, list[int]] = {}
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
"""
For remote prefill, pull all prompt blocks from remote
asynchronously relative to engine execution.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
* the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
* true if the external KV cache tokens will be loaded
asynchronously (between scheduler steps).
"""
params = request.kv_transfer_params
logger.debug(
"MooncakeConnector get_num_new_matched_tokens: "
"num_computed_tokens=%s, kv_transfer_params=%s",
num_computed_tokens,
params,
)
if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
token_ids = request.prompt_token_ids or []
count = len(token_ids) - num_computed_tokens
if count > 0:
return count, True
# No remote prefill for this request.
return 0, False
def update_state_after_alloc(
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
):
params = request.kv_transfer_params
logger.debug(
"MooncakeConnector update_state_after_alloc: "
"num_external_tokens=%s, kv_transfer_params=%s",
num_external_tokens,
params,
)
if not params:
return
if params.get("do_remote_prefill"):
assert self.kv_role != "kv_producer"
if all(p in params for p in ("remote_host", "remote_port")):
# If remote_blocks and num_external_tokens = 0, we have
# a full prefix cache hit on the D worker. We need to call
# send_notif in _read_blocks to free the memory on the P.
local_block_ids = (
blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []
)
# Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (request, local_block_ids)
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer",
params,
)
# Only trigger 1 KV transfer per request.
params["do_remote_prefill"] = False
elif params.get("do_remote_decode"):
# Add an empty list to worker to create event.
self._reqs_need_send[request.request_id] = []
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
meta = MooncakeConnectorMetadata()
# Loop through scheduled reqs and convert to RecvReqMeta.
if self.kv_role != "kv_producer":
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
)
self._reqs_need_recv.clear()
if self.kv_role != "kv_consumer":
for req_id, block_ids in self._reqs_need_send.items():
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params={},
load_remote_cache=False,
)
self._reqs_need_send.clear()
return meta
    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Once a request is finished, determine whether request blocks
        should be freed now or will be sent asynchronously and freed later.

        Returns:
            ``(delay_free_blocks, kv_transfer_params)``. When
            ``delay_free_blocks`` is True the blocks stay allocated until the
            worker reports the send as finished; the returned params (if not
            None) are handed to the decode instance so it can pull the KV.
        """
        params = request.kv_transfer_params
        logger.debug(
            "MooncakeConnector request_finished, request_status=%s, "
            "kv_transfer_params=%s",
            request.status,
            params,
        )
        if not params:
            return False, None
        if params.get("do_remote_prefill"):
            # If do_remote_prefill is still True when the request is finished,
            # update_state_after_alloc must not have been called (the request
            # must have been aborted before it was scheduled).
            # To avoid stranding the prefill blocks in the prefill instance,
            # we must add empty block_ids to _reqs_need_recv so that our
            # worker side will notify and free blocks in the prefill instance.
            assert self.kv_role != "kv_producer"
            self._reqs_need_recv[request.request_id] = (request, [])
            params["do_remote_prefill"] = False
            return False, None
        if (
            not params.get("do_remote_decode")
            or request.status != RequestStatus.FINISHED_LENGTH_CAPPED
        ):
            # Nothing to send: not a remote-decode request, or it finished
            # for a reason other than reaching its length cap.
            return False, None
        assert self.kv_role != "kv_consumer"
        # TODO: check whether block_ids actually ever be 0. If not we could
        # remove the conditional below
        delay_free_blocks = len(block_ids) > 0
        if delay_free_blocks:
            self._reqs_need_send[request.request_id] = block_ids
        # Tell the decode side where to pull the KV from (our side channel).
        return delay_free_blocks, dict(
            do_remote_prefill=True,
            do_remote_decode=False,
            remote_host=self.side_channel_host,
            remote_port=self.side_channel_port,
        )
class MooncakeConnectorWorker:
    """Implementation of Worker side methods"""

    def __init__(self, vllm_config: VllmConfig, engine_id: str):
        """Set up the Mooncake transfer engine, side-channel bookkeeping and
        the background sender/receiver machinery for this TP rank.

        Raises:
            RuntimeError: If the transfer engine fails to initialize.
        """
        logger.info("Initializing Mooncake Transfer Engine worker %s", engine_id)
        self.vllm_config = vllm_config
        self.engine = TransferEngine()
        self.hostname = get_ip()
        # Initialize the engine in "P2PHANDSHAKE" mode over "rdma".
        ret_value = self.engine.initialize(self.hostname, "P2PHANDSHAKE", "rdma", "")
        if ret_value != 0:
            raise RuntimeError("Mooncake Transfer Engine initialization failed.")
        self.rpc_port = self.engine.get_rpc_port()
        logger.debug(
            "Mooncake Transfer Engine initialized at %s:%d",
            self.hostname,
            self.rpc_port,
        )
        # Mooncake handshake port.
        self.side_channel_port: int = get_mooncake_side_channel_port(vllm_config)
        self.engine_id: EngineId = engine_id
        self.tp_rank = get_tensor_model_parallel_rank()
        self.world_size = get_tensor_model_parallel_world_size()
        self.tp_group = get_tp_group()
        # Set for real in register_kv_caches() once cache shapes are known.
        self.num_blocks = 0
        assert vllm_config.kv_transfer_config
        self.kv_role = vllm_config.kv_transfer_config.kv_role
        self.num_workers = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
            "num_workers", 10
        )
        # Base addresses of the registered KV tensors (filled later).
        self.kv_caches_base_addr: list[int] = []
        self.device_kv_caches: dict[str, torch.Tensor] = {}
        # Requests whose blocks must be sent to D, guarded by a lock since
        # both the sender pool and get_finished() mutate it.
        self.reqs_need_send: SendReqMeta = SendReqMeta(reqs={}, lock=threading.Lock())
        # For kv_both, we will act both prefiller and decoder.
        if self.kv_role != "kv_consumer":
            # Background thread for sending kvcaches to D.
            self._mooncake_sender_t: threading.Thread | None = None
            # Background thread for processing new sending requests.
            self._sender_executor = ThreadPoolExecutor(
                max_workers=self.num_workers, thread_name_prefix="vllm-mooncake-sender"
            )
            logger.debug(
                "Mooncake Prefiller: use %d workers to send kvcaches", self.num_workers
            )
        if self.kv_role != "kv_producer":
            # Dedicated asyncio loop, driven by a daemon thread, that runs
            # all KV-pull coroutines on the decode side.
            self.receiver_loop = asyncio.new_event_loop()
            self._mooncake_receiver_t = threading.Thread(
                target=self._receiver_loop, args=(self.receiver_loop,), daemon=True
            )
            self._mooncake_receiver_t.start()
            logger.debug("Mooncake Decoder: start receiver thread")
        # Sender set is touched from plain threads (threading.Lock); the
        # receiver set only from the receiver event loop (asyncio.Lock).
        self.finished_sending_reqs: FinishedSendReqSet = FinishedSendReqSet(
            set(), threading.Lock()
        )
        self.finished_recving_reqs: FinishedReceiveReqSet = FinishedReceiveReqSet(
            set(), asyncio.Lock()
        )
        self.block_size = vllm_config.cache_config.block_size
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.use_mla = self.model_config.use_mla
        backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            use_mla=self.use_mla,
        )
        self.backend_name = backend.get_name()
        self.kv_cache_layout = get_kv_cache_layout()
        logger.debug("Detected attention backend %s", self.backend_name)
        logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
        self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
        self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
        self.kv_topo = TpKVTopology(
            tp_rank=self.tp_rank,
            engine_id=self.engine_id,
            remote_tp_size=self._tp_size,  # shared state
            remote_block_size=self._block_size,  # shared state
            is_mla=self.use_mla,
            total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
            attn_backend=backend,
        )
        self._use_pallas = self.kv_topo._use_pallas
        # Sync context for the sender threads, async context for the
        # receiver event loop.
        self.zmq_ctx = zmq.Context()
        self.async_zmq_ctx = zmq.asyncio.Context()
        self._encoder = msgspec.msgpack.Encoder()
        self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
def __del__(self):
self.shutdown()
    def shutdown(self):
        """Cleanup background threads on destruction."""
        # Terminate the ZMQ contexts first: threads blocked in socket calls
        # then get zmq.ContextTerminated and can unwind before we join them
        # (the sender thread catches that exception explicitly).
        self.zmq_ctx.term()
        self.async_zmq_ctx.term()
        if self.kv_role != "kv_consumer":
            # Don't wait for queued sender jobs; their sockets are already
            # being torn down.
            self._sender_executor.shutdown(wait=False)
            if self._mooncake_sender_t:
                self._mooncake_sender_t.join()
        if self.kv_role != "kv_producer" and self.receiver_loop.is_running():
            # Stop the receiver event loop from within its own thread, then
            # wait for the thread to exit.
            self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop)
            self._mooncake_receiver_t.join()
    def _receiver_loop(self, loop: asyncio.AbstractEventLoop):
        # Entry point of the receiver thread: bind `loop` to this thread and
        # run it until shutdown() stops it via call_soon_threadsafe.
        asyncio.set_event_loop(loop)
        loop.run_forever()
    def _mooncake_sender(
        self, ready_event: threading.Event, base_port: int, tp_rank: int
    ):
        """
        Background thread that listens for Mooncake requests, dispatches them
        to a thread pool, and sends acknowledgments upon completion.

        Args:
            ready_event: Set once the listening socket is bound, so the
                caller can safely advertise the port.
            base_port: First port of this engine's side-channel window; this
                rank listens on ``base_port + tp_rank``.
            tp_rank: Tensor-parallel rank of this worker.
        """
        # Frontend ROUTER socket: decode workers connect here with their
        # transfer requests.
        frontend_path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
        frontend = make_zmq_socket(self.zmq_ctx, frontend_path, zmq.ROUTER)
        logger.debug("Mooncake sender starting listening on path: %s", frontend_path)
        # Backend inproc PULL socket: pool workers push their completion
        # status here so only this thread touches the frontend socket.
        backend_path = make_zmq_path("inproc", str(uuid.uuid4()))
        backend = make_zmq_socket(self.zmq_ctx, backend_path, zmq.PULL)
        poller = zmq.Poller()
        poller.register(frontend, zmq.POLLIN)
        poller.register(backend, zmq.POLLIN)
        ready_event.set()
        try:
            while True:
                sockets = dict(poller.poll())
                if frontend in sockets:
                    # ROUTER frames: (client identity, empty delimiter,
                    # serialized MooncakeAgentMetadata).
                    identity, _, metadata_bytes = frontend.recv_multipart()
                    self._sender_executor.submit(
                        self._sender_worker,
                        identity,
                        metadata_bytes,
                        backend_path,
                    )
                if backend in sockets:
                    # Relay a worker's status back to the requesting client.
                    identity, status = backend.recv_multipart()
                    frontend.send_multipart((identity, b"", status))
        except zmq.ContextTerminated:
            # Normal shutdown path: shutdown() terminated the context.
            logger.debug("ZMQ context terminated, exiting Mooncake sender thread.")
        except Exception as e:
            logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e))
        finally:
            frontend.close()
            backend.close()
    def _sender_worker(
        self, identity: bytes, metadata_bytes: bytes, worker_channel_path: str
    ):
        """Thread-pool job: decode one transfer request, push its KV blocks
        to the decode instance, and report the outcome to the sender thread.

        Args:
            identity: ZMQ ROUTER identity of the requesting client; echoed
                back so the ack is routed to the right peer.
            metadata_bytes: msgpack-encoded MooncakeAgentMetadata.
            worker_channel_path: inproc path of the sender thread's PULL
                socket, used to report TRANS_DONE / TRANS_ERROR.
        """
        status = TRANS_ERROR
        try:
            metadata = self._decoder.decode(metadata_bytes)
            self.send_kv_to_decode(metadata)
            status = TRANS_DONE
        except Exception as e:
            logger.error("Error processing Mooncake handshake: %s", e)
        finally:
            # Each job uses its own short-lived PUSH socket; ZMQ sockets are
            # not shared across threads.
            pusher = make_zmq_socket(self.zmq_ctx, worker_channel_path, zmq.PUSH)
            try:
                pusher.send_multipart((identity, status))
            except zmq.ZMQError as e:
                logger.warning(
                    "Internal error, maybe the server is shutting down. Error: %s",
                    e,
                )
            finally:
                pusher.close()
def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
send_reqs: list[tuple[ReqId, SendBlockMeta]] = []
with self.reqs_need_send.lock:
for req_id in meta.request_ids:
send_meta = self.reqs_need_send.reqs.get(req_id)
if send_meta is None:
logger.warning("Request %s not found in reqs_need_send", req_id)
return
# Mark it as not expired. We will send it now.
send_meta.expire_time = float("inf")
send_reqs.append((req_id, send_meta))
self._send_blocks(send_reqs, meta)
with self.reqs_need_send.lock:
for req_id in meta.request_ids:
del self.reqs_need_send.reqs[req_id]
with self.finished_sending_reqs.lock:
self.finished_sending_reqs.set.update(meta.request_ids)
    def _send_blocks(
        self,
        send_reqs: list[tuple[ReqId, SendBlockMeta]],
        agent_meta: MooncakeAgentMetadata,
    ):
        """Write KV blocks for a batch of requests to the remote instance.

        Builds one (src, dst, length) triple per contiguous block run, per
        registered cache tensor, per request, then issues a single
        batch_transfer_sync_write to the remote session.

        Raises:
            RuntimeError: If the batched transfer reports a nonzero status.
        """
        src_ptrs = []
        dst_ptrs = []
        lengths = []
        local_base_addr = self.kv_caches_base_addr
        remote_base_addr = agent_meta.kv_caches_base_addr
        block_len = self.block_len
        remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}"
        assert len(send_reqs) == len(agent_meta.block_ids)
        for (req_id, send_meta), remote_block_ids in zip(
            send_reqs, agent_meta.block_ids
        ):
            # Wait until start_load_kv() has filled in the local block ids.
            send_meta.ready.wait()
            num_remote_blocks = len(remote_block_ids)
            if num_remote_blocks == 0:
                continue
            local_block_ids = send_meta.local_block_ids
            # Partial prefix cache hit: just read uncomputed blocks.
            num_local_blocks = len(local_block_ids)
            assert num_local_blocks >= num_remote_blocks
            if num_local_blocks > num_remote_blocks:
                local_block_ids = local_block_ids[-num_remote_blocks:]
            # Group by indices
            group_local_block_ids, group_remote_block_ids = group_concurrent_contiguous(
                local_block_ids, remote_block_ids
            )
            for local_layer_addr, remote_layer_addr in zip(
                local_base_addr, remote_base_addr
            ):
                for group_local_block_id, group_remote_block_id in zip(
                    group_local_block_ids, group_remote_block_ids
                ):
                    # One transfer per contiguous run: start address of the
                    # run's first block, length covering the whole run.
                    src_ptrs.append(
                        local_layer_addr + group_local_block_id[0] * block_len
                    )
                    dst_ptrs.append(
                        remote_layer_addr + group_remote_block_id[0] * block_len
                    )
                    lengths.append(block_len * len(group_local_block_id))
            logger.debug(
                "Sending kv_caches for request %s (%d blocks) to %s",
                req_id,
                num_remote_blocks,
                remote_session,
            )
        start_time = time.perf_counter()
        ret_value = self.engine.batch_transfer_sync_write(
            remote_session, src_ptrs, dst_ptrs, lengths
        )
        if ret_value != 0:
            raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")
        logger.debug(
            "Sending to %s done, took %s",
            remote_session,
            time.perf_counter() - start_time,
        )
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register the KV Cache data in mooncake.

        Records each distinct cache tensor's base address and size with the
        transfer engine, derives ``num_blocks`` / ``block_len``, and (on the
        producer side) starts the background sender thread.

        Raises:
            RuntimeError: If batch memory registration fails.
        """
        logger.info("Registering KV_Caches. use_mla: %s", self.use_mla)
        kv_data_ptrs = []
        kv_data_lens = []
        seen_base_addresses = []
        # Whether K and V live in separate tensors that must each be
        # registered individually.
        split_k_and_v = self.kv_topo.split_k_and_v
        tensor_size_bytes = None
        for layer_name, cache_or_caches in kv_caches.items():
            logger.debug(
                "registering layer %s with shape %s", layer_name, cache_or_caches.shape
            )
            cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
            for cache in cache_list:
                base_addr = cache.data_ptr()
                # Layers may alias the same storage; register each base
                # address only once.
                if base_addr in seen_base_addresses:
                    continue
                seen_base_addresses.append(base_addr)
                curr_tensor_size_bytes = cache.nbytes
                if tensor_size_bytes is None:
                    tensor_size_bytes = curr_tensor_size_bytes
                    self.num_blocks = cache.shape[0]
                assert tensor_size_bytes == curr_tensor_size_bytes, (
                    "All kv cache tensors must have the same size"
                )
                # The kernel's block dimension must agree with the cache
                # config's block_size.
                kernel_block_size = cache.shape[-2 if self.use_mla else -3]
                assert self.block_size == kernel_block_size
                kv_data_ptrs.append(base_addr)
                kv_data_lens.append(tensor_size_bytes)
        self.kv_caches_base_addr = seen_base_addresses
        ret_value = self.engine.batch_register_memory(kv_data_ptrs, kv_data_lens)
        if ret_value != 0:
            raise RuntimeError("Mooncake batch memory registration failed.")
        assert tensor_size_bytes is not None
        assert self.num_blocks != 0
        assert tensor_size_bytes % self.num_blocks == 0
        # Bytes per block within a single registered tensor.
        self.block_len = tensor_size_bytes // self.num_blocks
        self.device_kv_caches = kv_caches
        logger.debug(
            "registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len
        )
        # No need to launch server for D node.
        if self.kv_role == "kv_consumer":
            return
        ready_event = threading.Event()
        self._mooncake_sender_t = threading.Thread(
            target=self._mooncake_sender,
            args=(ready_event, self.side_channel_port, self.tp_rank),
            daemon=True,
            name="mooncake_sender",
        )
        self._mooncake_sender_t.start()
        ready_event.wait()  # Wait for listener ZMQ socket to be ready.
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
async with self.finished_recving_reqs.lock:
finished_recving_reqs = self.finished_recving_reqs.set
self.finished_recving_reqs.set = set()
return finished_recving_reqs
    def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
        """
        Get requests that are done sending or recving on this specific worker.
        The scheduler process (via the MultiprocExecutor) will use this output
        to track which workers are done.

        Returns:
            ``(finished_sending, finished_recving)``; each element is None
            when empty.
        """
        fut = None
        if self.kv_role != "kv_producer":
            # Kick off draining the receiver-side set on its event loop; the
            # future is resolved after the sender-side set is drained below.
            fut = asyncio.run_coroutine_threadsafe(
                self.fetch_finished_recving_reqs(), self.receiver_loop
            )
        if self.kv_role != "kv_consumer":
            with self.finished_sending_reqs.lock:
                # Swap-drain under the lock.
                finished_sending_reqs = self.finished_sending_reqs.set
                self.finished_sending_reqs.set = set()
        else:
            finished_sending_reqs = set()
        finished_recving_reqs = fut.result() if fut else set()
        if finished_sending_reqs or finished_recving_reqs:
            logger.debug(
                "Rank %s, get_finished: %s requests done sending "
                "and %s requests done recving",
                self.tp_rank,
                len(finished_sending_reqs),
                len(finished_recving_reqs),
            )
        # Handle timeout to avoid stranding blocks on remote.
        now = time.perf_counter()
        with self.reqs_need_send.lock:
            # Requests being sent right now carry expire_time == inf (see
            # send_kv_to_decode) and are never treated as expired here.
            expired_reqs = [
                req_id
                for req_id, send_meta in self.reqs_need_send.reqs.items()
                if send_meta.expire_time < now
            ]
            for req_id in expired_reqs:
                logger.warning(
                    "Request %s timed out after %d seconds without "
                    "being sent. Freeing its blocks on the producer side.",
                    req_id,
                    envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
                )
                del self.reqs_need_send.reqs[req_id]
        if expired_reqs:
            # Report expired requests as "sent" so their blocks are freed.
            finished_sending_reqs.update(expired_reqs)
        return finished_sending_reqs or None, finished_recving_reqs or None
async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]):
req_ids, block_ids = map(list, zip(*req_blocks))
metadata = MooncakeAgentMetadata(
remote_hostname=self.hostname,
remote_port=self.rpc_port,
request_ids=req_ids,
kv_caches_base_addr=self.kv_caches_base_addr,
block_ids=block_ids,
)
encoded_data = self._encoder.encode(metadata)
logger.debug(
"Size of encoded MooncakeAgentMetadata: %d bytes", len(encoded_data)
)
logger.debug("Sending kv transfer request for %s on path: %s", req_ids, path)
# Send query for the request.
sock: zmq.asyncio.Socket = make_zmq_socket(
self.async_zmq_ctx, path, zmq.REQ, bind=False, linger=0
)
sock.setsockopt(zmq.RCVTIMEO, 60000)
try:
await sock.send(encoded_data)
ret_msg = await sock.recv()
if ret_msg != TRANS_DONE:
logger.error(
"Error happens during tranfering kvcache for %s, see logs in prefiller.", # noqa: E501
req_ids,
)
return
except zmq.ContextTerminated:
logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.")
except Exception as e:
logger.error("MooncakeAgentMetadata transfer failed for %s: %s", req_ids, e)
return
finally:
sock.close()
async with self.finished_recving_reqs.lock:
self.finished_recving_reqs.set.update(req_ids)
logger.debug("pulling kv_caches for %s finished", req_ids)
def group_kv_pull(self, metadata: MooncakeConnectorMetadata):
kv_pulls = defaultdict(list)
for req_id, meta in metadata.reqs_to_recv.items():
logger.debug(
"start_load_kv for request %s from remote engine. "
"Num local_block_ids: %s.",
req_id,
len(meta.local_block_ids),
)
path = make_zmq_path(
"tcp", meta.remote_host, meta.remote_port + self.tp_rank
)
kv_pulls[path].append((req_id, meta.local_block_ids))
return kv_pulls
    def start_load_kv(self, metadata: MooncakeConnectorMetadata):
        """Kick off the KV transfers described by this step's metadata.

        Decoder side: schedule one async pull per remote path on the
        receiver event loop. Producer side: create or complete per-request
        send metadata so the sender thread can serve incoming pulls.
        """
        if self.kv_role != "kv_producer":
            kv_pulls = self.group_kv_pull(metadata)
            for path, req_blocks in kv_pulls.items():
                # Fire-and-forget: completion is tracked via
                # finished_recving_reqs inside receive_kv().
                asyncio.run_coroutine_threadsafe(
                    self.receive_kv(path, req_blocks), self.receiver_loop
                )
        if self.kv_role != "kv_consumer":
            with self.reqs_need_send.lock:
                for req_id, block_ids in metadata.reqs_to_send.items():
                    if block_ids:
                        # Already gone through request_finished()
                        send_meta = self.reqs_need_send.reqs[req_id]
                        send_meta.local_block_ids = block_ids
                        # Unblock any sender waiting in _send_blocks() and
                        # start the abort-timeout countdown.
                        send_meta.ready.set()
                        send_meta.expire_time = (
                            time.perf_counter()
                            + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
                        )
                    else:
                        # From update_state_after_alloc(),
                        # but not reach request_finished() yet
                        self.reqs_need_send.reqs[req_id] = SendBlockMeta(
                            local_block_ids=[], ready=threading.Event()
                        )
def group_concurrent_contiguous(
    src_indices: list[int], dst_indices: list[int]
) -> tuple[list[list[int]], list[list[int]]]:
    """Split two equal-length index lists into aligned runs that are
    contiguous (step of +1) in *both* sequences simultaneously.

    Returns a pair of run lists; empty inputs yield ``([], [])``.
    """
    if not src_indices:
        return [], []
    src_arr = np.asarray(src_indices)
    dst_arr = np.asarray(dst_indices)
    # A run ends wherever either sequence jumps by anything other than +1.
    cut_points = np.flatnonzero((np.diff(src_arr) != 1) | (np.diff(dst_arr) != 1)) + 1
    src_runs = [run.tolist() for run in np.split(src_arr, cut_points)]
    dst_runs = [run.tolist() for run in np.split(dst_arr, cut_points)]
    return src_runs, dst_runs
def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
    """Derive this engine's Mooncake side-channel base port.

    Each data-parallel rank gets a disjoint window of
    ``tensor_parallel_size`` ports on top of the configured bootstrap port;
    centralizing the computation keeps scheduler and worker in agreement.
    """
    parallel_config = vllm_config.parallel_config
    dp_offset = (
        parallel_config.data_parallel_rank * parallel_config.tensor_parallel_size
    )
    return envs.VLLM_MOONCAKE_BOOTSTRAP_PORT + dp_offset

View File

@ -20,10 +20,10 @@ import torch
import zmq
from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp,
KVConnectorBase_V1,
@ -668,128 +668,6 @@ class NixlConnectorScheduler:
class NixlConnectorWorker:
"""Implementation of Worker side methods"""
@dataclass
class TpKVTopology:
"""
Helper class for tensor parallel and KV topology information for
mapping between local and remote TP workers.
"""
tp_rank: int
remote_tp_size: dict[EngineId, int]
is_mla: bool
total_num_kv_heads: int
attn_backend: type[AttentionBackend]
engine_id: EngineId
remote_block_size: dict[EngineId, int]
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
# we just mock num_blocks to 1 for the dimension check below.
self._is_kv_layout_blocks_first = (
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
)
attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
@property
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
@property
def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present).
return not (
self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
)
@property
def tp_size(self) -> int:
return self.remote_tp_size[self.engine_id]
@property
def block_size(self) -> int:
return self.remote_block_size[self.engine_id]
def tp_ratio(
self,
remote_tp_size: int,
) -> int:
"""
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
"""
assert self.tp_size % remote_tp_size == 0, (
f"Local tensor parallel size {self.tp_size} is not divisible "
f"by remote tensor parallel size {remote_tp_size}."
)
return self.tp_size // remote_tp_size
def block_size_ratio(
self,
remote_block_size: int,
) -> float:
"""
Calculate the block size ratio between local and remote TP.
"""
assert self.block_size % remote_block_size == 0, (
f"Local block size {self.block_size} is not divisible "
f"by remote block size {remote_block_size} or vice versa."
)
return self.block_size // remote_block_size
def tp_ratio_from_engine_id(
self,
remote_engine_id: EngineId,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.tp_ratio(remote_tp_size)
def block_size_ratio_from_engine_id(
self,
remote_engine_id: EngineId,
) -> float:
remote_block_size = self.remote_block_size[remote_engine_id]
return self.block_size_ratio(remote_block_size)
def is_kv_replicated(self, engine_id: EngineId) -> bool:
"""
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
"""
tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
# MLA is always replicated as the hidden dim can't be split.
return self.is_mla or self.is_kv_replicated(remote_engine_id)
def get_target_remote_rank(
self,
remote_tp_size: int,
) -> int:
"""
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
"""
tp_ratio = self.tp_ratio(remote_tp_size)
return self.tp_rank // tp_ratio
def get_target_remote_rank_from_engine_id(
self,
remote_engine_id: EngineId,
) -> int:
remote_tp_size = self.remote_tp_size[remote_engine_id]
return self.get_target_remote_rank(remote_tp_size)
def __init__(self, vllm_config: VllmConfig, engine_id: str):
if NixlWrapper is None:
logger.error("NIXL is not available")
@ -958,7 +836,7 @@ class NixlConnectorWorker:
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats()
self.kv_topo = self.TpKVTopology(
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state

View File

@ -1169,17 +1169,13 @@ def init_distributed_environment(
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
if config is not None and config.parallel_config.nnodes > 1:
parallel_config = config.parallel_config
ip = parallel_config.master_addr
rank = parallel_config.data_parallel_rank * world_size + rank
world_size = parallel_config.world_size_across_dp
port = parallel_config.master_port
distributed_init_method = get_distributed_init_method(ip, port)
elif (
if (
config is not None
and config.parallel_config.data_parallel_size > 1
and config.parallel_config.distributed_executor_backend != "external_launcher"
and (
config.parallel_config.nnodes > 1
or config.parallel_config.data_parallel_size > 1
)
):
parallel_config = config.parallel_config
# adjust to take into account data parallelism
@ -1187,15 +1183,22 @@ def init_distributed_environment(
rank = parallel_config.data_parallel_rank * world_size + rank
# adjust the world size to take into account data parallelism
world_size = parallel_config.world_size_across_dp
ip = parallel_config.data_parallel_master_ip
port = parallel_config.get_next_dp_init_port()
distributed_init_method = get_distributed_init_method(ip, port)
logger.debug(
"Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
world_size,
rank,
distributed_init_method,
)
# Use appropriate IP and port based on configuration
if parallel_config.nnodes > 1:
ip = parallel_config.master_addr
port = parallel_config.master_port
distributed_init_method = get_distributed_init_method(ip, port)
else:
ip = parallel_config.data_parallel_master_ip
port = parallel_config.get_next_dp_init_port()
distributed_init_method = get_distributed_init_method(ip, port)
logger.debug(
"Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
world_size,
rank,
distributed_init_method,
)
if not torch.distributed.is_initialized():
logger.info(
"world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",

View File

@ -183,7 +183,9 @@ class AnthropicServingMessages(OpenAIServingChat):
if anthropic_request.stream:
req.stream = anthropic_request.stream
req.stream_options = StreamOptions.validate({"include_usage": True})
req.stream_options = StreamOptions.validate(
{"include_usage": True, "continuous_usage_stats": True}
)
if anthropic_request.tool_choice is None:
req.tool_choice = None
@ -323,6 +325,12 @@ class AnthropicServingMessages(OpenAIServingChat):
content=[],
model=origin_chunk.model,
),
usage=AnthropicUsage(
input_tokens=origin_chunk.usage.prompt_tokens
if origin_chunk.usage
else 0,
output_tokens=0,
),
)
first_item = False
data = chunk.model_dump_json(exclude_unset=True)

View File

@ -536,7 +536,7 @@ def resolve_hf_chat_template(
def _resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
*,
model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
@ -593,7 +593,7 @@ def resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
given_format: ChatTemplateContentFormatOption,
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
*,
model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
@ -627,11 +627,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
maximum per prompt.
"""
def __init__(self, model_config: ModelConfig, tokenizer: TokenizerLike):
def __init__(self, model_config: ModelConfig):
super().__init__()
self._model_config = model_config
self._tokenizer = tokenizer
self._items_by_modality = defaultdict[str, list[_T | None]](list)
self._uuids_by_modality = defaultdict[str, list[str | None]](list)
@ -1139,11 +1138,19 @@ def validate_chat_template(chat_template: Path | str | None):
not any(c in chat_template for c in JINJA_CHARS)
and not Path(chat_template).exists()
):
raise ValueError(
f"The supplied chat template string ({chat_template}) "
f"appears path-like, but doesn't exist!"
# Try to find the template in the built-in templates directory
from vllm.transformers_utils.chat_templates.registry import (
CHAT_TEMPLATES_DIR,
)
builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
if not builtin_template_path.exists():
raise ValueError(
f"The supplied chat template string ({chat_template}) "
f"appears path-like, but doesn't exist! "
f"Tried: {chat_template} and {builtin_template_path}"
)
else:
raise TypeError(f"{type(chat_template)} is not a valid chat template type")
@ -1173,12 +1180,23 @@ def _load_chat_template(
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (
f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}"
# Try to load from the built-in templates directory
from vllm.transformers_utils.chat_templates.registry import (
CHAT_TEMPLATES_DIR,
)
raise ValueError(msg) from e
builtin_template_path = CHAT_TEMPLATES_DIR / chat_template
try:
with open(builtin_template_path) as f:
return f.read()
except OSError:
msg = (
f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be opened. "
f"Tried: {chat_template} and {builtin_template_path}. "
f"Reason: {e}"
)
raise ValueError(msg) from e
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
@ -1593,7 +1611,6 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
def parse_chat_messages(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: TokenizerLike,
content_format: _ChatTemplateContentFormat,
) -> tuple[
list[ConversationMessage],
@ -1601,7 +1618,7 @@ def parse_chat_messages(
MultiModalUUIDDict | None,
]:
conversation: list[ConversationMessage] = []
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
mm_tracker = MultiModalItemTracker(model_config)
for msg in messages:
sub_messages = _parse_chat_message_content(
@ -1625,7 +1642,6 @@ def parse_chat_messages(
def parse_chat_messages_futures(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: TokenizerLike,
content_format: _ChatTemplateContentFormat,
) -> tuple[
list[ConversationMessage],
@ -1633,7 +1649,7 @@ def parse_chat_messages_futures(
MultiModalUUIDDict | None,
]:
conversation: list[ConversationMessage] = []
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
mm_tracker = AsyncMultiModalItemTracker(model_config)
for msg in messages:
sub_messages = _parse_chat_message_content(

View File

@ -328,6 +328,105 @@ def render_for_completion(messages: list[Message]) -> list[int]:
return token_ids
def _parse_browser_tool_call(message: Message, recipient: str) -> ResponseOutputItem:
"""Parse browser tool calls (search, open, find) into web search items."""
if len(message.content) != 1:
raise ValueError("Invalid number of contents in browser message")
content = message.content[0]
# Parse JSON args (with retry detection)
try:
browser_call = json.loads(content.text)
except json.JSONDecodeError:
json_retry_output_message = (
f"Invalid JSON args, caught and retried: {content.text}"
)
browser_call = {
"query": json_retry_output_message,
"url": json_retry_output_message,
"pattern": json_retry_output_message,
}
# Create appropriate action based on recipient
if recipient == "browser.search":
action = ActionSearch(
query=f"cursor:{browser_call.get('query', '')}", type="search"
)
elif recipient == "browser.open":
action = ActionOpenPage(
url=f"cursor:{browser_call.get('url', '')}", type="open_page"
)
elif recipient == "browser.find":
action = ActionFind(
pattern=browser_call.get("pattern", ""),
url=f"cursor:{browser_call.get('url', '')}",
type="find",
)
else:
raise ValueError(f"Unknown browser action: {recipient}")
return ResponseFunctionWebSearch(
id=f"ws_{random_uuid()}",
action=action,
status="completed",
type="web_search_call",
)
def _parse_function_call(message: Message, recipient: str) -> list[ResponseOutputItem]:
"""Parse function calls into function tool call items."""
function_name = recipient.split(".")[-1]
output_items = []
for content in message.content:
random_id = random_uuid()
response_item = ResponseFunctionToolCall(
arguments=content.text,
call_id=f"call_{random_id}",
type="function_call",
name=function_name,
id=f"fc_{random_id}",
)
output_items.append(response_item)
return output_items
def _parse_reasoning_content(message: Message) -> list[ResponseOutputItem]:
"""Parse reasoning/analysis content into reasoning items."""
output_items = []
for content in message.content:
reasoning_item = ResponseReasoningItem(
id=f"rs_{random_uuid()}",
summary=[],
type="reasoning",
content=[
ResponseReasoningTextContent(text=content.text, type="reasoning_text")
],
status=None,
)
output_items.append(reasoning_item)
return output_items
def _parse_final_message(message: Message) -> ResponseOutputItem:
"""Parse final channel messages into output message items."""
contents = []
for content in message.content:
output_text = ResponseOutputText(
text=content.text,
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
)
contents.append(output_text)
return ResponseOutputMessage(
id=f"msg_{random_uuid()}",
content=contents,
role=message.author.role,
status="completed",
type="message",
)
def parse_output_message(message: Message) -> list[ResponseOutputItem]:
    """
    Parse a Harmony message into a list of output response items.

    Dispatches on the message's recipient and channel:

    * ``browser.*`` recipients become a single web-search call item.
    * The "analysis" channel becomes reasoning items.
    * The "commentary" channel becomes function tool call items for
      ``functions.*`` recipients, or reasoning items for built-in tools
      (``python``/``browser``/``container``).
    * The "final" channel becomes a single output message item.

    Args:
        message: The Harmony message to convert.

    Returns:
        The response output items derived from the message.

    Raises:
        ValueError: If the recipient or channel is not recognized.
    """
    output_items: list[ResponseOutputItem] = []
    recipient = message.recipient
    # Browser tool calls
    if recipient is not None and recipient.startswith("browser."):
        output_items.append(_parse_browser_tool_call(message, recipient))
    # Analysis channel (reasoning/chain-of-thought)
    elif message.channel == "analysis":
        output_items.extend(_parse_reasoning_content(message))
    # Commentary channel
    elif message.channel == "commentary":
        # Function calls
        if recipient is not None and recipient.startswith("functions."):
            output_items.extend(_parse_function_call(message, recipient))
        # Built-in tools on commentary channel are treated as reasoning for now
        elif recipient is not None and (
            recipient.startswith("python")
            or recipient.startswith("browser")
            or recipient.startswith("container")
        ):
            output_items.extend(_parse_reasoning_content(message))
        else:
            raise ValueError(f"Unknown recipient: {recipient}")
    # Final output message
    elif message.channel == "final":
        output_items.append(_parse_final_message(message))
    else:
        raise ValueError(f"Unknown channel: {message.channel}")
    return output_items

View File

@ -834,7 +834,6 @@ class LLM:
conversation, mm_data, mm_uuids = parse_chat_messages(
msgs,
model_config,
tokenizer,
content_format=resolved_content_format,
)

View File

@ -105,7 +105,7 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.tokenizers import DeepseekV32Tokenizer, MistralTokenizer, TokenizerLike
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,
@ -1088,11 +1088,6 @@ class OpenAIServing:
Sequence[RequestPrompt],
list[EngineTokensPrompt],
]:
if tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format(
@ -1105,7 +1100,6 @@ class OpenAIServing:
conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
messages,
model_config,
tokenizer,
content_format=resolved_content_format,
)
@ -1128,6 +1122,13 @@ class OpenAIServing:
messages=messages,
**_chat_template_kwargs,
)
elif isinstance(tokenizer, DeepseekV32Tokenizer):
request_prompt = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
model_config=model_config,
**_chat_template_kwargs,
)
else:
request_prompt = apply_hf_chat_template(
tokenizer=tokenizer,

View File

@ -30,6 +30,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"deepseekv31_tool_parser",
"DeepSeekV31ToolParser",
),
"deepseek_v32": (
"deepseekv32_tool_parser",
"DeepSeekV32ToolParser",
),
"ernie45": (
"ernie45_tool_parser",
"Ernie45ToolParser",

Some files were not shown because too many files have changed in this diff Show More