diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py deleted file mode 100644 index bbed80ebe8476..0000000000000 --- a/.buildkite/generate_index.py +++ /dev/null @@ -1,46 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import argparse -import os - -template = """ - - -

Links for vLLM

- {x86_wheel}
- {arm_wheel}
- - -""" - -parser = argparse.ArgumentParser() -parser.add_argument("--wheel", help="The wheel path.", required=True) -args = parser.parse_args() - -filename = os.path.basename(args.wheel) - -with open("index.html", "w") as f: - print(f"Generated index.html for {args.wheel}") - # sync the abi tag with .buildkite/scripts/upload-wheels.sh - if "x86_64" in filename: - x86_wheel = filename - arm_wheel = filename.replace("x86_64", "aarch64").replace( - "manylinux1", "manylinux2014" - ) - elif "aarch64" in filename: - x86_wheel = filename.replace("aarch64", "x86_64").replace( - "manylinux2014", "manylinux1" - ) - arm_wheel = filename - else: - raise ValueError(f"Unsupported wheel: {filename}") - # cloudfront requires escaping the '+' character - f.write( - template.format( - x86_wheel=x86_wheel, - x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"), - arm_wheel=arm_wheel, - arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"), - ) - ) diff --git a/.buildkite/scripts/generate-nightly-index.py b/.buildkite/scripts/generate-nightly-index.py index 8d09ba178db7b..4d28ec9619de9 100644 --- a/.buildkite/scripts/generate-nightly-index.py +++ b/.buildkite/scripts/generate-nightly-index.py @@ -7,13 +7,14 @@ import argparse import json -import re import sys from dataclasses import asdict, dataclass from pathlib import Path from typing import Any from urllib.parse import quote +import regex as re + if not sys.version_info >= (3, 12): raise RuntimeError("This script requires Python 3.12 or higher.") diff --git a/.buildkite/scripts/hardware_ci/run-npu-test.sh b/.buildkite/scripts/hardware_ci/run-npu-test.sh index 29c8f5ed5a91a..0db1abe37ba11 100644 --- a/.buildkite/scripts/hardware_ci/run-npu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-npu-test.sh @@ -74,6 +74,7 @@ FROM ${BASE_IMAGE_NAME} # Define environments ENV DEBIAN_FRONTEND=noninteractive +ENV SOC_VERSION="ascend910b1" RUN pip config set global.index-url http://cache-service-vllm.nginx-pypi-cache.svc.cluster.local:${PYPI_CACHE_PORT}/pypi/simple && \ pip config set global.trusted-host cache-service-vllm.nginx-pypi-cache.svc.cluster.local && \ diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh index 2eaa91c04086c..0ac8fdd45bd0a 100644 --- a/.buildkite/scripts/upload-wheels.sh +++ b/.buildkite/scripts/upload-wheels.sh @@ -81,7 +81,7 @@ else alias_arg="" fi -$PYTHON .buildkite/scripts/generate-nightly-index.py --version "$SUBPATH" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" $alias_arg +$PYTHON pip install regex && .buildkite/scripts/generate-nightly-index.py --version "$SUBPATH" --current-objects "$obj_json" --output-dir "$INDICES_OUTPUT_DIR" $alias_arg # copy indices to // unconditionally echo "Uploading indices to $S3_COMMIT_PREFIX" diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index ee4fdebae5675..022b6ea236d54 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -987,7 +987,8 @@ steps: commands: - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-mm-small.txt --tp-size=1 -- label: Multi-Modal Models Test (Extended) 1 +- label: Multi-Modal Models Test (Extended) 1 # 60min + timeout_in_minutes: 120 mirror_hardwares: [amdexperimental] agent_pool: mi325_1 # grade: Blocking @@ -1011,7 +1012,8 @@ steps: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' -- label: Multi-Modal Models Test (Extended) 3 +- label: Multi-Modal Models Test (Extended) 3 # 75min + timeout_in_minutes: 150 mirror_hardwares: [amdexperimental] agent_pool: mi325_1 # grade: Blocking diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f79e9266559f6..a79f0b0c6bbdf 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -387,6 +387,7 @@ steps: working_dir: "/vllm-workspace/examples" source_file_dependencies: - vllm/entrypoints + - vllm/multimodal - examples/ commands: - pip install tensorizer # for tensorizer test diff --git a/README.md b/README.md index abbb63158f166..5c040fe4a66d2 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,7 @@ Compute Resources: - Alibaba Cloud - AMD - Anyscale +- Arm - AWS - Crusoe Cloud - Databricks diff --git a/benchmarks/benchmark_hash.py b/benchmarks/benchmark_hash.py new file mode 100644 index 0000000000000..08cdc012d6527 --- /dev/null +++ b/benchmarks/benchmark_hash.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Micro benchmark comparing built-in hash(), SHA-256, and xxHash. + +This focuses on a single test payload shaped like the prefix-cache hash input: + (32-byte bytes object, 32-int tuple) + +Usage: + python benchmarks/hash_micro_benchmark.py --iterations 20000 +""" + +from __future__ import annotations + +import argparse +import random +import statistics +import time +from collections.abc import Callable, Iterable + +from vllm.utils.hashing import sha256, xxhash + + +def _generate_test_data(seed: int) -> tuple[bytes, tuple[int, ...]]: + """Generate a deterministic test payload.""" + random.seed(seed) + bytes_data = bytes(random.getrandbits(8) for _ in range(32)) + int_tuple = tuple(random.randint(1, 1_000_000) for _ in range(32)) + return (bytes_data, int_tuple) + + +def _benchmark_func(func: Callable[[tuple], object], data: tuple, iterations: int): + """Return (avg_seconds, std_seconds) for hashing `data` `iterations` times.""" + times: list[float] = [] + + # Warm-up to avoid first-run noise. + for _ in range(200): + func(data) + + for _ in range(iterations): + start = time.perf_counter() + func(data) + end = time.perf_counter() + times.append(end - start) + + avg = statistics.mean(times) + std = statistics.stdev(times) if len(times) > 1 else 0.0 + return avg, std + + +def _run_benchmarks( + benchmarks: Iterable[tuple[str, Callable[[tuple], object]]], + data: tuple, + iterations: int, +): + """Yield (name, avg, std) for each benchmark, skipping unavailable ones.""" + for name, func in benchmarks: + try: + avg, std = _benchmark_func(func, data, iterations) + except ModuleNotFoundError as exc: + print(f"Skipping {name}: {exc}") + continue + yield name, avg, std + + +def builtin_hash(data: tuple) -> int: + """Wrapper for Python's built-in hash().""" + return hash(data) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--iterations", + type=int, + default=10_000, + help="Number of measured iterations per hash function.", + ) + parser.add_argument( + "--seed", type=int, default=42, help="Random seed for test payload." + ) + args = parser.parse_args() + + data = _generate_test_data(args.seed) + benchmarks = ( + ("SHA256 (pickle)", sha256), + ("xxHash (pickle)", xxhash), + ("built-in hash()", builtin_hash), + ) + + print("=" * 60) + print("HASH FUNCTION MICRO BENCHMARK") + print("=" * 60) + print("Test data: (32-byte bytes object, 32-int tuple)") + print(f"Iterations: {args.iterations:,}") + print("=" * 60) + + results = list(_run_benchmarks(benchmarks, data, args.iterations)) + builtin_entry = next((r for r in results if r[0] == "built-in hash()"), None) + + print("\nResults:") + for name, avg, std in results: + print(f" {name:16s}: {avg * 1e6:8.2f} ± {std * 1e6:6.2f} μs") + + if builtin_entry: + _, builtin_avg, _ = builtin_entry + print("\n" + "=" * 60) + print("SUMMARY (relative to built-in hash())") + print("=" * 60) + for name, avg, _ in results: + if name == "built-in hash()": + continue + speed_ratio = avg / builtin_avg + print(f"• {name} is {speed_ratio:.1f}x slower than built-in hash()") + else: + print("\nBuilt-in hash() result missing; cannot compute speed ratios.") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/benchmark_prefix_block_hash.py b/benchmarks/benchmark_prefix_block_hash.py new file mode 100644 index 0000000000000..8bcd8af0d3102 --- /dev/null +++ b/benchmarks/benchmark_prefix_block_hash.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Simple benchmark to compare prefix-cache block hashing algorithms. + +Example: + python benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32 +""" + +from __future__ import annotations + +import argparse +import random +import statistics +import sys +import time +from collections.abc import Callable, Iterable, Sequence + +from vllm.utils.hashing import get_hash_fn_by_name +from vllm.v1.core.kv_cache_utils import BlockHash, hash_block_tokens, init_none_hash + +SUPPORTED_ALGOS = ("sha256", "sha256_cbor", "xxhash", "xxhash_cbor") + + +def _generate_blocks( + num_blocks: int, block_size: int, vocab_size: int, seed: int +) -> list[list[int]]: + rng = random.Random(seed) + return [ + [rng.randrange(vocab_size) for _ in range(block_size)] + for _ in range(num_blocks) + ] + + +def _hash_all_blocks( + hash_fn: Callable[[object], bytes], + blocks: Iterable[Sequence[int]], +) -> float: + parent_hash: BlockHash | None = None + start = time.perf_counter() + for block in blocks: + parent_hash = hash_block_tokens(hash_fn, parent_hash, block, extra_keys=None) + end = time.perf_counter() + return end - start + + +def _benchmark( + hash_algo: str, + blocks: list[list[int]], + trials: int, +) -> tuple[float, float, float] | None: + try: + hash_fn = get_hash_fn_by_name(hash_algo) + init_none_hash(hash_fn) + timings = [_hash_all_blocks(hash_fn, blocks) for _ in range(trials)] + except ModuleNotFoundError as exc: + print(f"Skipping {hash_algo}: {exc}", file=sys.stderr) + return None + + avg = statistics.mean(timings) + best = min(timings) + # throughput: tokens / second + tokens_hashed = len(blocks) * len(blocks[0]) + throughput = tokens_hashed / best + return avg, best, throughput + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--num-blocks", type=int, default=10000, help="Block count.") + parser.add_argument("--block-size", type=int, default=32, help="Tokens per block.") + parser.add_argument( + "--vocab-size", type=int, default=32000, help="Token id range [0, vocab_size)." + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed.") + parser.add_argument( + "--trials", type=int, default=5, help="Number of timed trials per algorithm." + ) + parser.add_argument( + "--algorithms", + nargs="+", + default=SUPPORTED_ALGOS, + choices=SUPPORTED_ALGOS, + help="Hash algorithms to benchmark.", + ) + args = parser.parse_args() + + blocks = _generate_blocks( + args.num_blocks, args.block_size, args.vocab_size, args.seed + ) + print( + f"Benchmarking {len(args.algorithms)} algorithms on " + f"{args.num_blocks} blocks (block size={args.block_size})." + ) + + for algo in args.algorithms: + result = _benchmark(algo, blocks, args.trials) + if result is None: + continue + + avg, best, throughput = result + print( + f"{algo:14s} avg: {avg:.6f}s best: {best:.6f}s " + f"throughput: {throughput / 1e6:.2f}M tokens/s" + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py new file mode 100644 index 0000000000000..04921dafbdbea --- /dev/null +++ b/benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py @@ -0,0 +1,244 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from enum import Enum +from itertools import product +from typing import Any + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement + +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _per_token_group_quant_fp8_colmajor, + silu_mul_per_token_group_quant_fp8_colmajor, +) +from vllm.triton_utils import triton +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used + +from .utils import ArgPool, Bench, CudaGraphBenchParams + +GROUP_SIZE = 128 +FLOAT8_T = torch.float8_e4m3fn + + +def print_timers(timers: list[TMeasurement], cuda_graph_nops: int): + print( + f"Note : The timings reported above is for {cuda_graph_nops} " + "consecutive invocations of the benchmarking functions. " + f"Please divide by {cuda_graph_nops} for single invocation " + "timings." + ) + compare = TBenchmark.Compare(timers) + compare.print() + + +class ImplType(Enum): + SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1 + REFERENCE = 2 + + def get_impl(self): + if self == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR: + return silu_mul_per_token_group_quant_fp8_colmajor + elif self == ImplType.REFERENCE: + return reference + raise ValueError(f"Unrecognized ImplType {self}") + + +@dataclass +class BenchmarkTensors: + input: torch.Tensor + output: torch.Tensor + + # Reference act output tensor + ref_act_out: torch.Tensor + ref_quant_out: torch.Tensor + + @staticmethod + def make(T: int, N: int) -> "BenchmarkTensors": + assert T % GROUP_SIZE == 0 + assert N % (GROUP_SIZE * 2) == 0 + + input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda") + + # silu_mul_per_token_group_quant_fp8_colmajor output. + output = torch.rand((T, N // 2), dtype=torch.bfloat16, device="cuda").to( + FLOAT8_T + ) + + # reference output. + ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda") + ref_quant_out = torch.empty( + (T, N // 2), dtype=torch.bfloat16, device="cuda" + ).to(FLOAT8_T) + + return BenchmarkTensors( + input=input, + output=output, + ref_act_out=ref_act_out, + ref_quant_out=ref_quant_out, + ) + + @property + def T(self): + return self.input.size(0) + + @property + def N(self): + return self.input.size(1) + + def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]: + if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR: + return { + "input": self.input, + "output": self.output, + "use_ue8m0": is_deep_gemm_e8m0_used(), + } + elif impl_type == ImplType.REFERENCE: + return { + "input": self.input, + "act_out": self.ref_act_out, + "quant_out": self.ref_quant_out, + "use_ue8m0": is_deep_gemm_e8m0_used(), + } + raise ValueError(f"Unrecognized impl_type {impl_type}") + + +def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool): + """ + Reference triton quant kernel from, + vllm.model_executor.layers.quantization.utils.fp8_utils + """ + assert quant_out.size() == x.size() + # Allocate the scale tensor column-major format. + shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1] + x_q = quant_out + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + + M = x.numel() // GROUP_SIZE + N = GROUP_SIZE + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + + finfo = torch.finfo(FLOAT8_T) + fp8_min = finfo.min + fp8_max = finfo.max + + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + GROUP_SIZE, + x.shape[1], + x.stride(0), + x_s.stride(1), + eps=1e-10, + fp8_min=fp8_min, + fp8_max=fp8_max, + use_ue8m0=use_ue8m0, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + return x_q, x_s + + +def reference( + input: torch.Tensor, + act_out: torch.Tensor, + quant_out: torch.Tensor, + use_ue8m0: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + torch.ops._C.silu_and_mul(act_out, input) + return reference_quant(act_out, quant_out, use_ue8m0) + + +def bench_impl( + bench_tensors: list[BenchmarkTensors], impl_type: ImplType +) -> TMeasurement: + T = bench_tensors[0].T + N = bench_tensors[0].N + + arg_pool_size = len(bench_tensors) + kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors] + + # warmup + for kwargs in kwargs_list: + impl_type.get_impl()(**kwargs) + torch.cuda.synchronize() + + # Merge into a single kwargs and qualify arguments as ArgPool + kwargs = {k: ArgPool([]) for k in kwargs_list[0]} + for _kwargs in kwargs_list: + for k, v in _kwargs.items(): + kwargs[k].values.append(v) + + cuda_graph_params = None + cuda_graph_params = CudaGraphBenchParams(arg_pool_size) + timer = None + with Bench( + cuda_graph_params, + "silu-mul-quant", + f"num_tokens={T}, N={N}", + impl_type.name, + impl_type.get_impl(), + **kwargs, + ) as bench: + timer = bench.run() + return timer + + +def test_correctness(T: int, N: int): + print(f"Testing num_tokens={T}, N={N} ...") + + bench_tensor = BenchmarkTensors.make(T, N) + + def output_from_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]: + return impl.get_impl()(**bench_tensor.make_impl_kwargs(impl)) + + # reference output + ref_out_q, ref_out_s = output_from_impl(ImplType.REFERENCE) + + # test ouptut + out_q, out_s = output_from_impl( + ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR + ) + + torch.testing.assert_close(ref_out_q.to(torch.float32), out_q.to(torch.float32)) + torch.testing.assert_close(ref_out_s, out_s) + + +def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]: + timers = [] + for N, T in product(Ns, Ts): + test_correctness(T, N) + + bench_tensors: list[BenchmarkTensors] = [ + BenchmarkTensors.make(T, N) for _ in range(arg_pool_size) + ] + + silu_mul_quant_timer = bench_impl( + bench_tensors, ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR + ) + timers.append(silu_mul_quant_timer) + reference_timer = bench_impl(bench_tensors, ImplType.REFERENCE) + timers.append(reference_timer) + + print_timers( + [silu_mul_quant_timer, reference_timer], cuda_graph_nops=arg_pool_size + ) + + print_timers(timers, cuda_graph_nops=arg_pool_size) + + return timers + + +if __name__ == "__main__": + T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)] + N = [2048, 4096, 8192] + + print(f"T = {T}, N = {N}") + run(T, N, arg_pool_size=8) diff --git a/docker/Dockerfile b/docker/Dockerfile index 006481b23cb9f..0d50d97e54c6c 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -150,6 +150,97 @@ ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0 12.0' ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} #################### BASE BUILD IMAGE #################### +#################### CSRC BUILD IMAGE #################### +FROM base AS csrc-build +ARG TARGETPLATFORM + +ARG PIP_INDEX_URL UV_INDEX_URL +ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL +ARG PYTORCH_CUDA_INDEX_BASE_URL + +# install build dependencies +COPY requirements/build.txt requirements/build.txt + +# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out +# Reference: https://github.com/astral-sh/uv/pull/1694 +ENV UV_HTTP_TIMEOUT=500 +ENV UV_INDEX_STRATEGY="unsafe-best-match" +# Use copy mode to avoid hardlink failures with Docker cache mounts +ENV UV_LINK_MODE=copy + +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \ + --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') + +WORKDIR /workspace + +COPY pyproject.toml setup.py CMakeLists.txt ./ +COPY cmake cmake/ +COPY csrc csrc/ +COPY vllm/envs.py vllm/envs.py +COPY vllm/__init__.py vllm/__init__.py + +# max jobs used by Ninja to build extensions +ARG max_jobs=2 +ENV MAX_JOBS=${max_jobs} +# number of threads used by nvcc +ARG nvcc_threads=8 +ENV NVCC_THREADS=$nvcc_threads + +ARG USE_SCCACHE +ARG SCCACHE_DOWNLOAD_URL=https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz +ARG SCCACHE_ENDPOINT +ARG SCCACHE_BUCKET_NAME=vllm-build-sccache +ARG SCCACHE_REGION_NAME=us-west-2 +ARG SCCACHE_S3_NO_CREDENTIALS=0 + +# Flag to control whether to use pre-built vLLM wheels +ARG VLLM_USE_PRECOMPILED="" +ARG VLLM_MERGE_BASE_COMMIT="" +ARG VLLM_MAIN_CUDA_VERSION="" + +# Use dummy version for csrc-build wheel (only .so files are extracted, version doesn't matter) +ENV SETUPTOOLS_SCM_PRETEND_VERSION="0.0.0+csrc.build" + +# if USE_SCCACHE is set, use sccache to speed up compilation +RUN --mount=type=cache,target=/root/.cache/uv \ + if [ "$USE_SCCACHE" = "1" ]; then \ + echo "Installing sccache..." \ + && curl -L -o sccache.tar.gz ${SCCACHE_DOWNLOAD_URL} \ + && tar -xzf sccache.tar.gz \ + && sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \ + && rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \ + && if [ ! -z ${SCCACHE_ENDPOINT} ] ; then export SCCACHE_ENDPOINT=${SCCACHE_ENDPOINT} ; fi \ + && export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \ + && export SCCACHE_REGION=${SCCACHE_REGION_NAME} \ + && export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \ + && export SCCACHE_IDLE_TIMEOUT=0 \ + && export CMAKE_BUILD_TYPE=Release \ + && export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \ + && export VLLM_PRECOMPILED_WHEEL_COMMIT="${VLLM_MERGE_BASE_COMMIT}" \ + && export VLLM_MAIN_CUDA_VERSION="${VLLM_MAIN_CUDA_VERSION}" \ + && export VLLM_DOCKER_BUILD_CONTEXT=1 \ + && sccache --show-stats \ + && python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \ + && sccache --show-stats; \ + fi + +ARG vllm_target_device="cuda" +ENV VLLM_TARGET_DEVICE=${vllm_target_device} +ENV CCACHE_DIR=/root/.cache/ccache +RUN --mount=type=cache,target=/root/.cache/ccache \ + --mount=type=cache,target=/root/.cache/uv \ + if [ "$USE_SCCACHE" != "1" ]; then \ + # Clean any existing CMake artifacts + rm -rf .deps && \ + mkdir -p .deps && \ + export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" && \ + export VLLM_PRECOMPILED_WHEEL_COMMIT="${VLLM_MERGE_BASE_COMMIT}" && \ + export VLLM_DOCKER_BUILD_CONTEXT=1 && \ + python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \ + fi +#################### CSRC BUILD IMAGE #################### + #################### WHEEL BUILD IMAGE #################### FROM base AS build ARG TARGETPLATFORM @@ -172,66 +263,28 @@ RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') +WORKDIR /workspace + +COPY --from=csrc-build /workspace/dist /precompiled-wheels + COPY . . + ARG GIT_REPO_CHECK=0 RUN --mount=type=bind,source=.git,target=.git \ if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi -# max jobs used by Ninja to build extensions -ARG max_jobs=2 -ENV MAX_JOBS=${max_jobs} -# number of threads used by nvcc -ARG nvcc_threads=8 -ENV NVCC_THREADS=$nvcc_threads - -ARG USE_SCCACHE -ARG SCCACHE_DOWNLOAD_URL=https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz -ARG SCCACHE_ENDPOINT -ARG SCCACHE_BUCKET_NAME=vllm-build-sccache -ARG SCCACHE_REGION_NAME=us-west-2 -ARG SCCACHE_S3_NO_CREDENTIALS=0 - -# Flag to control whether to use pre-built vLLM wheels -ARG VLLM_USE_PRECOMPILED="" -ARG VLLM_MAIN_CUDA_VERSION="" - -# if USE_SCCACHE is set, use sccache to speed up compilation -RUN --mount=type=cache,target=/root/.cache/uv \ - --mount=type=bind,source=.git,target=.git \ - if [ "$USE_SCCACHE" = "1" ]; then \ - echo "Installing sccache..." \ - && curl -L -o sccache.tar.gz ${SCCACHE_DOWNLOAD_URL} \ - && tar -xzf sccache.tar.gz \ - && sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \ - && rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \ - && if [ ! -z ${SCCACHE_ENDPOINT} ] ; then export SCCACHE_ENDPOINT=${SCCACHE_ENDPOINT} ; fi \ - && export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \ - && export SCCACHE_REGION=${SCCACHE_REGION_NAME} \ - && export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \ - && export SCCACHE_IDLE_TIMEOUT=0 \ - && export CMAKE_BUILD_TYPE=Release \ - && export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" \ - && export VLLM_MAIN_CUDA_VERSION="${VLLM_MAIN_CUDA_VERSION}" \ - && export VLLM_DOCKER_BUILD_CONTEXT=1 \ - && sccache --show-stats \ - && python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 \ - && sccache --show-stats; \ - fi - ARG vllm_target_device="cuda" ENV VLLM_TARGET_DEVICE=${vllm_target_device} -ENV CCACHE_DIR=/root/.cache/ccache -RUN --mount=type=cache,target=/root/.cache/ccache \ - --mount=type=cache,target=/root/.cache/uv \ - --mount=type=bind,source=.git,target=.git \ - if [ "$USE_SCCACHE" != "1" ]; then \ - # Clean any existing CMake artifacts - rm -rf .deps && \ - mkdir -p .deps && \ - export VLLM_USE_PRECOMPILED="${VLLM_USE_PRECOMPILED}" && \ - export VLLM_DOCKER_BUILD_CONTEXT=1 && \ - python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \ - fi + +# Skip adding +precompiled suffix to version (preserves git-derived version) +ENV VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX=1 + +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,source=.git,target=.git \ + if [ "${vllm_target_device}" = "cuda" ]; then \ + export VLLM_PRECOMPILED_WHEEL_LOCATION=$(ls /precompiled-wheels/*.whl); \ + fi && \ + python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 # Install DeepGEMM from source ARG DEEPGEMM_GIT_REF @@ -527,7 +580,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ else \ BITSANDBYTES_VERSION="0.46.1"; \ fi; \ - uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.0' + uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.3' ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/docs/assets/contributing/dockerfile-stages-dependency.png b/docs/assets/contributing/dockerfile-stages-dependency.png index b327eb2151f50..7420ca4d89441 100644 Binary files a/docs/assets/contributing/dockerfile-stages-dependency.png and b/docs/assets/contributing/dockerfile-stages-dependency.png differ diff --git a/docs/benchmarking/cli.md b/docs/benchmarking/cli.md index 44a4c40125952..1ce6b611745b1 100644 --- a/docs/benchmarking/cli.md +++ b/docs/benchmarking/cli.md @@ -670,6 +670,35 @@ vllm bench serve \ +### 🧪 Hashing Benchmarks + +
+Show more + +Two helper scripts live in `benchmarks/` to compare hashing options used by prefix caching and related utilities. They are standalone (no server required) and help choose a hash algorithm before enabling prefix caching in production. + +- `benchmarks/benchmark_hash.py`: Micro-benchmark that measures per-call latency of three implementations on a representative `(bytes, tuple[int])` payload. + +```bash +python benchmarks/benchmark_hash.py --iterations 20000 --seed 42 +``` + +- `benchmarks/benchmark_prefix_block_hash.py`: End-to-end block hashing benchmark that runs the full prefix-cache hash pipeline (`hash_block_tokens`) across many fake blocks and reports throughput. + +```bash +python benchmarks/benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32 --trials 5 +``` + +Supported algorithms: `sha256`, `sha256_cbor`, `xxhash`, `xxhash_cbor`. Install optional deps to exercise all variants: + +```bash +uv pip install xxhash cbor2 +``` + +If an algorithm’s dependency is missing, the script will skip it and continue. + +
+ ### ⚡ Request Prioritization Benchmark
diff --git a/docs/community/sponsors.md b/docs/community/sponsors.md index 8abb07caaab62..fd1c82376d086 100644 --- a/docs/community/sponsors.md +++ b/docs/community/sponsors.md @@ -18,6 +18,7 @@ Compute Resources: - Alibaba Cloud - AMD - Anyscale +- Arm - AWS - Crusoe Cloud - Databricks diff --git a/docs/design/metrics.md b/docs/design/metrics.md index 59cb6ba46fe17..13264f6861b0c 100644 --- a/docs/design/metrics.md +++ b/docs/design/metrics.md @@ -57,15 +57,15 @@ vLLM also provides [a reference example](../../examples/online_serving/prometheu The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important: - `vllm:e2e_request_latency_seconds_bucket` - End to end request latency measured in seconds. -- `vllm:prompt_tokens_total` - Prompt tokens. -- `vllm:generation_tokens_total` - Generation tokens. +- `vllm:prompt_tokens` - Prompt tokens. +- `vllm:generation_tokens` - Generation tokens. - `vllm:time_per_output_token_seconds` - Inter-token latency (Time Per Output Token, TPOT) in seconds. - `vllm:time_to_first_token_seconds` - Time to First Token (TTFT) latency in seconds. - `vllm:num_requests_running` (also, `_swapped` and `_waiting`) - Number of requests in the RUNNING, WAITING, and SWAPPED states. - `vllm:gpu_cache_usage_perc` - Percentage of used cache blocks by vLLM. - `vllm:request_prompt_tokens` - Request prompt length. - `vllm:request_generation_tokens` - Request generation length. -- `vllm:request_success_total` - Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached. +- `vllm:request_success` - Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached. - `vllm:request_queue_time_seconds` - Queue time. - `vllm:request_prefill_time_seconds` - Requests prefill time. - `vllm:request_decode_time_seconds` - Requests decode time. @@ -571,9 +571,9 @@ model and then validate those tokens with the larger model. - `vllm:spec_decode_draft_acceptance_rate` (Gauge) - `vllm:spec_decode_efficiency` (Gauge) -- `vllm:spec_decode_num_accepted_tokens_total` (Counter) -- `vllm:spec_decode_num_draft_tokens_total` (Counter) -- `vllm:spec_decode_num_emitted_tokens_total` (Counter) +- `vllm:spec_decode_num_accepted_tokens` (Counter) +- `vllm:spec_decode_num_draft_tokens` (Counter) +- `vllm:spec_decode_num_emitted_tokens` (Counter) There is a PR under review () to add "prompt lookup (ngram)" speculative decoding to v1. Other techniques will follow. We should diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 44aaa65218cc4..48341d199cb80 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -90,7 +90,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels | cutlass_fp8 | standard,
batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],
[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],
[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] | | flashinfer | standard | nvfp4,
fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],
[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] | | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | -| deep gemm+triton2 | standard,
batched | all1 | G(128),A,T | silu, gelu | 6 | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],
[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] | | marlin | standard,
batched | 3 / N/A | 3 / N/A | silu,
swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],
[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] | @@ -114,5 +113,5 @@ The following table shows "families" of modular kernels that are intended to wor | backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses | |---------|-----------------------------------------|----------------------------------------------| | deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,
`TritonExperts`,
`TritonOrDeepGemmExperts`,
`CutlassExpertsFp8`,
`MarlinExperts` | -| deepep_low_latency,
pplx | `DeepEPLLPrepareAndFinalize`,
`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,
`BatchedTritonExperts`,
`BatchedTritonOrDeepGemmExperts`,
`CutlassBatchedExpertsFp8`,
`BatchedMarlinExperts` | +| deepep_low_latency,
pplx | `DeepEPLLPrepareAndFinalize`,
`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,
`BatchedTritonExperts`,
`CutlassBatchedExpertsFp8`,
`BatchedMarlinExperts` | | flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` | diff --git a/docs/features/README.md b/docs/features/README.md index 5faf3768f3214..684802301a44f 100644 --- a/docs/features/README.md +++ b/docs/features/README.md @@ -54,7 +54,7 @@ th:not(:first-child) { | beam-search | ✅ | ✅ | ✅ | [❌](https://github.com/vllm-project/vllm/issues/6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](https://github.com/vllm-project/vllm/issues/7968) | ❔ | ✅ | ✅ | | | [prompt-embeds](prompt_embeds.md) | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ | -\* Chunked prefill and prefix caching are only applicable to last-token pooling. +\* Chunked prefill and prefix caching are only applicable to last-token or all pooling with causal attention. ^ LoRA is only applicable to the language backbone of multimodal models. ### Feature x Hardware diff --git a/docs/features/mooncake_connector_usage.md b/docs/features/mooncake_connector_usage.md new file mode 100644 index 0000000000000..653ea29ad9433 --- /dev/null +++ b/docs/features/mooncake_connector_usage.md @@ -0,0 +1,58 @@ +# MooncakeConnector Usage Guide + +## About Mooncake + +Mooncake aims to enhance the inference efficiency of large language models (LLMs), especially in slow object storage environments, by constructing a multi-level caching pool on high-speed interconnected DRAM/SSD resources. Compared to traditional caching systems, Mooncake utilizes (GPUDirect) RDMA technology to transfer data directly in a zero-copy manner, while maximizing the use of multi-NIC resources on a single machine. + +For more details about Mooncake, please refer to [Mooncake project](https://github.com/kvcache-ai/Mooncake) and [Mooncake documents](https://kvcache-ai.github.io/Mooncake/). + +## Prerequisites + +### Installation + +Install mooncake through pip: `uv pip install mooncake-transfer-engine`. + +Refer to [Mooncake official repository](https://github.com/kvcache-ai/Mooncake) for more installation instructions + +## Usage + +### Prefiller Node (192.168.0.2) + +```bash +vllm serve Qwen/Qwen2.5-7B-Instruct --port 8010 --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_producer"}' +``` + +### Decoder Node (192.168.0.3) + +```bash +vllm serve Qwen/Qwen2.5-7B-Instruct --port 8020 --kv-transfer-config '{"kv_connector":"MooncakeConnector","kv_role":"kv_consumer"}' +``` + +### Proxy + +```bash +python tests/v1/kv_connector/nixl_integration/toy_proxy_server.py --prefiller-host 192.168.0.2 --prefiller-port 8010 --decoder-host 192.168.0.3 --decoder-port 8020 +``` + +> NOTE: The Mooncake Connector currently uses the proxy from nixl_integration. This will be replaced with a self-developed proxy in the future. + +Now you can send requests to the proxy server through port 8000. + +## Environment Variables + +- `VLLM_MOONCAKE_BOOTSTRAP_PORT`: Port for Mooncake bootstrap server + - Default: 8998 + - Required only for prefiller instances + - Each vLLM worker needs a unique port on its host; using the same port number across different hosts is fine + - For TP/DP deployments, each worker's port on a node is computed as: base_port + dp_rank * tp_size + tp_rank + - Used for the decoder notifying the prefiller + +- `VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT`: Timeout (in seconds) for automatically releasing the prefiller’s KV cache for a particular request. (Optional) + - Default: 480 + - If a request is aborted and the decoder has not yet notified the prefiller, the prefill instance will release its KV-cache blocks after this timeout to avoid holding them indefinitely. + +## KV Role Options + +- **kv_producer**: For prefiller instances that generate KV caches +- **kv_consumer**: For decoder instances that consume KV caches from prefiller +- **kv_both**: Enables symmetric functionality where the connector can act as both producer and consumer. This provides flexibility for experimental setups and scenarios where the role distinction is not predetermined. diff --git a/docs/features/multimodal_inputs.md b/docs/features/multimodal_inputs.md index 4656ee43ea251..2b25dc7666c37 100644 --- a/docs/features/multimodal_inputs.md +++ b/docs/features/multimodal_inputs.md @@ -795,14 +795,12 @@ The following example demonstrates how to pass image embeddings to the OpenAI se ??? code ```python + from vllm.utils.serial_utils import tensor2base64 + image_embedding = torch.load(...) grid_thw = torch.load(...) # Required by Qwen/Qwen2-VL-2B-Instruct - buffer = io.BytesIO() - torch.save(image_embedding, buffer) - buffer.seek(0) - binary_data = buffer.read() - base64_image_embedding = base64.b64encode(binary_data).decode('utf-8') + base64_image_embedding = tensor2base64(image_embedding) client = OpenAI( # defaults to os.environ.get("OPENAI_API_KEY") diff --git a/docs/getting_started/installation/cpu.apple.inc.md b/docs/getting_started/installation/cpu.apple.inc.md index 4dc707d5f9a14..9f1f6e3821397 100644 --- a/docs/getting_started/installation/cpu.apple.inc.md +++ b/docs/getting_started/installation/cpu.apple.inc.md @@ -4,9 +4,6 @@ vLLM has experimental support for macOS with Apple Silicon. For now, users must Currently the CPU implementation for macOS supports FP32 and FP16 datatypes. -!!! warning - There are no pre-built wheels or images for this device, so you must build vLLM from source. - # --8<-- [end:installation] # --8<-- [start:requirements] @@ -20,6 +17,8 @@ Currently the CPU implementation for macOS supports FP32 and FP16 datatypes. # --8<-- [end:set-up-using-python] # --8<-- [start:pre-built-wheels] +Currently, there are no pre-built Apple silicon CPU wheels. + # --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] @@ -78,6 +77,8 @@ uv pip install -e . # --8<-- [end:build-wheel-from-source] # --8<-- [start:pre-built-images] +Currently, there are no pre-built Arm silicon CPU images. + # --8<-- [end:pre-built-images] # --8<-- [start:build-image-from-source] diff --git a/docs/getting_started/installation/cpu.arm.inc.md b/docs/getting_started/installation/cpu.arm.inc.md index 9cae9ed1a212e..156f31f633d57 100644 --- a/docs/getting_started/installation/cpu.arm.inc.md +++ b/docs/getting_started/installation/cpu.arm.inc.md @@ -1,11 +1,6 @@ # --8<-- [start:installation] -vLLM has been adapted to work on ARM64 CPUs with NEON support, leveraging the CPU backend initially developed for the x86 platform. - -ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes. - -!!! warning - There are no pre-built wheels or images for this device, so you must build vLLM from source. +vLLM offers basic model inferencing and serving on Arm CPU platform, with support NEON, data types FP32, FP16 and BF16. # --8<-- [end:installation] # --8<-- [start:requirements] @@ -20,6 +15,23 @@ ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes. # --8<-- [end:set-up-using-python] # --8<-- [start:pre-built-wheels] +Pre-built vLLM wheels for Arm are available since version 0.11.2. These wheels contain pre-compiled C++ binaries. +Please replace `` in the commands below with a specific version string (e.g., `0.11.2`). + +```bash +uv pip install --pre vllm==+cpu --extra-index-url https://wheels.vllm.ai/%2Bcpu/ +``` + +??? console "pip" + ```bash + pip install --pre vllm==+cpu --extra-index-url https://wheels.vllm.ai/%2Bcpu/ + ``` + +The `uv` approach works for vLLM `v0.6.6` and later. A unique feature of `uv` is that packages in `--extra-index-url` have [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). If the latest public release is `v0.6.6.post1`, `uv`'s behavior allows installing a commit before `v0.6.6.post1` by specifying the `--extra-index-url`. In contrast, `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version. + +!!! note + Nightly wheels are currently unsupported for this architecture. (e.g. to bisect the behavior change, performance regression). + # --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] @@ -69,6 +81,8 @@ Testing has been conducted on AWS Graviton3 instances for compatibility. # --8<-- [end:build-wheel-from-source] # --8<-- [start:pre-built-images] +Currently, there are no pre-built Arm CPU images. + # --8<-- [end:pre-built-images] # --8<-- [start:build-image-from-source] ```bash diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 4b68cb4811789..210f720e2d92a 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -46,11 +46,25 @@ vLLM is a Python library that supports the following CPU variants. Select your C ### Pre-built wheels -Please refer to the instructions for [pre-built wheels on GPU](./gpu.md#pre-built-wheels). - When specifying the index URL, please make sure to use the `cpu` variant subdirectory. For example, the nightly build index is: `https://wheels.vllm.ai/nightly/cpu/`. +=== "Intel/AMD x86" + + --8<-- "docs/getting_started/installation/cpu.x86.inc.md:pre-built-wheels" + +=== "ARM AArch64" + + --8<-- "docs/getting_started/installation/cpu.arm.inc.md:pre-built-wheels" + +=== "Apple silicon" + + --8<-- "docs/getting_started/installation/cpu.apple.inc.md:pre-built-wheels" + +=== "IBM Z (S390X)" + + --8<-- "docs/getting_started/installation/cpu.s390x.inc.md:pre-built-wheels" + ### Build wheel from source #### Set up using Python-only build (without compilation) {#python-only-build} @@ -87,6 +101,18 @@ VLLM_USE_PRECOMPILED=1 VLLM_PRECOMPILED_WHEEL_VARIANT=cpu VLLM_TARGET_DEVICE=cpu --8<-- "docs/getting_started/installation/cpu.x86.inc.md:pre-built-images" +=== "ARM AArch64" + + --8<-- "docs/getting_started/installation/cpu.arm.inc.md:pre-built-images" + +=== "Apple silicon" + + --8<-- "docs/getting_started/installation/cpu.apple.inc.md:pre-built-images" + +=== "IBM Z (S390X)" + + --8<-- "docs/getting_started/installation/cpu.s390x.inc.md:pre-built-images" + ### Build image from source === "Intel/AMD x86" diff --git a/docs/getting_started/installation/cpu.s390x.inc.md b/docs/getting_started/installation/cpu.s390x.inc.md index c2163139a7c5d..4984c87c17b01 100644 --- a/docs/getting_started/installation/cpu.s390x.inc.md +++ b/docs/getting_started/installation/cpu.s390x.inc.md @@ -4,9 +4,6 @@ vLLM has experimental support for s390x architecture on IBM Z platform. For now, Currently, the CPU implementation for s390x architecture supports FP32 datatype only. -!!! warning - There are no pre-built wheels or images for this device, so you must build vLLM from source. - # --8<-- [end:installation] # --8<-- [start:requirements] @@ -21,6 +18,8 @@ Currently, the CPU implementation for s390x architecture supports FP32 datatype # --8<-- [end:set-up-using-python] # --8<-- [start:pre-built-wheels] +Currently, there are no pre-built IBM Z CPU wheels. + # --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] @@ -69,6 +68,8 @@ Execute the following commands to build and install vLLM from source. # --8<-- [end:build-wheel-from-source] # --8<-- [start:pre-built-images] +Currently, there are no pre-built IBM Z CPU images. + # --8<-- [end:pre-built-images] # --8<-- [start:build-image-from-source] diff --git a/docs/getting_started/installation/cpu.x86.inc.md b/docs/getting_started/installation/cpu.x86.inc.md index 310f179cb89ca..1fad7f4338822 100644 --- a/docs/getting_started/installation/cpu.x86.inc.md +++ b/docs/getting_started/installation/cpu.x86.inc.md @@ -17,6 +17,8 @@ vLLM supports basic model inferencing and serving on x86 CPU platform, with data # --8<-- [end:set-up-using-python] # --8<-- [start:pre-built-wheels] +Currently, there are no pre-built x86 CPU wheels. + # --8<-- [end:pre-built-wheels] # --8<-- [start:build-wheel-from-source] diff --git a/docs/getting_started/installation/gpu.rocm.inc.md b/docs/getting_started/installation/gpu.rocm.inc.md index c80ba9478f6be..21120cc6fcd98 100644 --- a/docs/getting_started/installation/gpu.rocm.inc.md +++ b/docs/getting_started/installation/gpu.rocm.inc.md @@ -5,9 +5,6 @@ vLLM supports AMD GPUs with ROCm 6.3 or above, and torch 2.8.0 and above. !!! tip [Docker](#set-up-using-docker) is the recommended way to use vLLM on ROCm. -!!! warning - There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source. - # --8<-- [end:installation] # --8<-- [start:requirements] diff --git a/docs/getting_started/installation/gpu.xpu.inc.md b/docs/getting_started/installation/gpu.xpu.inc.md index 620a660a240ed..7e9c6a2b9de07 100644 --- a/docs/getting_started/installation/gpu.xpu.inc.md +++ b/docs/getting_started/installation/gpu.xpu.inc.md @@ -2,9 +2,6 @@ vLLM initially supports basic model inference and serving on Intel GPU platform. -!!! warning - There are no pre-built wheels for this device, so you need build vLLM from source. Or you can use pre-built images which are based on vLLM released versions. - # --8<-- [end:installation] # --8<-- [start:requirements] diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 040107c11efcf..96d5ec25c0064 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -711,7 +711,6 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + IE | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | | `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + IE+ | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | | `Phi4MMForCausalLM` | Phi-4-multimodal | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | -| `Phi4MultimodalForCausalLM` | Phi-4-multimodal (HF Transformers) | T + I+ / T + A+ / I+ + A+ | `microsoft/Phi-4-multimodal-instruct` (with revision `refs/pr/70`), etc. | ✅︎ | ✅︎ | | `PixtralForConditionalGeneration` | Ministral 3 (Mistral format), Mistral 3 (Mistral format), Mistral Large 3 (Mistral format), Pixtral (Mistral format) | T + I+ | `mistralai/Ministral-3-3B-Instruct-2512`, `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Mistral-Large-3-675B-Instruct-2512` `mistralai/Pixtral-12B-2409` etc. | | ✅︎ | | `QwenVLForConditionalGeneration`^ | Qwen-VL | T + IE+ | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | | `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A+ | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | diff --git a/examples/offline_inference/lora_with_quantization_inference.py b/examples/offline_inference/lora_with_quantization_inference.py index dc5c6202fa57b..2f3564b597556 100644 --- a/examples/offline_inference/lora_with_quantization_inference.py +++ b/examples/offline_inference/lora_with_quantization_inference.py @@ -23,31 +23,23 @@ def create_test_prompts( # this is an example of using quantization without LoRA ( "My name is", - SamplingParams( - temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 - ), + SamplingParams(temperature=0.0, logprobs=1, max_tokens=128), None, ), # the next three examples use quantization with LoRA ( "my name is", - SamplingParams( - temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 - ), + SamplingParams(temperature=0.0, logprobs=1, max_tokens=128), LoRARequest("lora-test-1", 1, lora_path), ), ( "The capital of USA is", - SamplingParams( - temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 - ), + SamplingParams(temperature=0.0, logprobs=1, max_tokens=128), LoRARequest("lora-test-2", 1, lora_path), ), ( "The capital of France is", - SamplingParams( - temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 - ), + SamplingParams(temperature=0.0, logprobs=1, max_tokens=128), LoRARequest("lora-test-3", 1, lora_path), ), ] diff --git a/examples/offline_inference/multilora_inference.py b/examples/offline_inference/multilora_inference.py index 5e5da2c0144c9..92021f9fb226c 100644 --- a/examples/offline_inference/multilora_inference.py +++ b/examples/offline_inference/multilora_inference.py @@ -27,9 +27,7 @@ def create_test_prompts( return [ ( "A robot may not injure a human being", - SamplingParams( - temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128 - ), + SamplingParams(temperature=0.0, logprobs=1, max_tokens=128), None, ), ( @@ -41,22 +39,12 @@ def create_test_prompts( ), ( "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 - SamplingParams( - temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=128, - ), + SamplingParams(temperature=0.0, logprobs=1, max_tokens=128), LoRARequest("sql-lora", 1, lora_path), ), ( "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 - SamplingParams( - temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=128, - ), + SamplingParams(temperature=0.0, logprobs=1, max_tokens=128), LoRARequest("sql-lora2", 2, lora_path), ), ] diff --git a/examples/online_serving/prompt_embed_inference_with_openai_client.py b/examples/online_serving/prompt_embed_inference_with_openai_client.py index 0bbe4b8f5ee9b..889be6820e70a 100644 --- a/examples/online_serving/prompt_embed_inference_with_openai_client.py +++ b/examples/online_serving/prompt_embed_inference_with_openai_client.py @@ -28,13 +28,11 @@ Dependencies: - openai """ -import base64 -import io - -import torch import transformers from openai import OpenAI +from vllm.utils.serial_utils import tensor2base64 + def main(): client = OpenAI( @@ -58,11 +56,7 @@ def main(): prompt_embeds = embedding_layer(token_ids).squeeze(0) # Prompt embeddings - buffer = io.BytesIO() - torch.save(prompt_embeds, buffer) - buffer.seek(0) - binary_data = buffer.read() - encoded_embeds = base64.b64encode(binary_data).decode("utf-8") + encoded_embeds = tensor2base64(prompt_embeds) completion = client.completions.create( model=model_name, diff --git a/examples/pooling/embed/openai_chat_embedding_client_for_multimodal.py b/examples/pooling/embed/openai_chat_embedding_client_for_multimodal.py index 47c2c5030078c..a7ab7e73e7d42 100644 --- a/examples/pooling/embed/openai_chat_embedding_client_for_multimodal.py +++ b/examples/pooling/embed/openai_chat_embedding_client_for_multimodal.py @@ -150,7 +150,8 @@ def run_siglip(client: OpenAI, model: str): Start the server using: vllm serve google/siglip-base-patch16-224 \ - --runner pooling + --runner pooling \ + --chat-template template_basic.jinja """ response = create_chat_embeddings( diff --git a/requirements/common.txt b/requirements/common.txt index 8b9e6b935bd20..f18560b98d16c 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -46,6 +46,7 @@ scipy # Required for phi-4-multimodal-instruct ninja # Required for xgrammar, rocm, tpu, xpu pybase64 # fast base64 implementation cbor2 # Required for cross-language serialization of hashable objects +ijson # Required for mistral streaming tool parser setproctitle # Used to set process names for better debugging and monitoring openai-harmony >= 0.0.3 # Required for gpt-oss anthropic == 0.71.0 diff --git a/requirements/cpu-build.txt b/requirements/cpu-build.txt index e18e0825fc428..1ea401a04a12c 100644 --- a/requirements/cpu-build.txt +++ b/requirements/cpu-build.txt @@ -3,7 +3,6 @@ ninja packaging>=24.2 setuptools>=77.0.3,<81.0.0 setuptools-scm>=8 ---extra-index-url https://download.pytorch.org/whl/cpu torch==2.9.1+cpu; platform_machine == "x86_64" or platform_machine == "s390x" torch==2.9.1; platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "aarch64" scons; platform_machine == "aarch64" # needed to build Arm Compute Library (ACL) diff --git a/requirements/cpu.txt b/requirements/cpu.txt index 21571be479c83..7a670812e8943 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -4,7 +4,6 @@ numba == 0.61.2; platform_machine != "s390x" # Required for N-gram speculative decoding # Dependencies for CPUs ---extra-index-url https://download.pytorch.org/whl/cpu torch==2.9.1+cpu; platform_machine == "x86_64" or platform_machine == "s390x" torch==2.9.1; platform_system == "Darwin" or platform_machine == "ppc64le" or platform_machine == "aarch64" diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index 53b012372be8e..7b2c665448a3b 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -42,6 +42,6 @@ tritonclient==2.51.0 numba == 0.61.2 # Required for N-gram speculative decoding numpy -runai-model-streamer[s3,gcs]==0.15.0 +runai-model-streamer[s3,gcs]==0.15.3 fastsafetensors>=0.1.10 pydantic>=2.12 # 2.11 leads to error on python 3.13 diff --git a/requirements/rocm.txt b/requirements/rocm.txt index abbd33d6e1240..05b9a21791c92 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -12,7 +12,7 @@ tensorizer==2.10.1 packaging>=24.2 setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 -runai-model-streamer[s3,gcs]==0.15.0 +runai-model-streamer[s3,gcs]==0.15.3 conch-triton-kernels==1.2.1 timm>=1.0.17 fastsafetensors @ git+https://github.com/foundation-model-stack/fastsafetensors.git@d6f998a03432b2452f8de2bb5cefb5af9795d459 diff --git a/requirements/test.in b/requirements/test.in index da7a7db1f00c9..dfae5b75821f8 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -51,7 +51,7 @@ tritonclient==2.51.0 arctic-inference == 0.1.1 # Required for suffix decoding test numba == 0.61.2 # Required for N-gram speculative decoding numpy -runai-model-streamer[s3,gcs]==0.15.0 +runai-model-streamer[s3,gcs]==0.15.3 fastsafetensors>=0.1.10 pydantic>=2.12 # 2.11 leads to error on python 3.13 decord==0.6.0 diff --git a/requirements/test.txt b/requirements/test.txt index c5f103b8b0d78..571194e05c1ba 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -965,11 +965,11 @@ rsa==4.9.1 # via google-auth rtree==1.4.0 # via torchgeo -runai-model-streamer==0.15.0 +runai-model-streamer==0.15.3 # via -r requirements/test.in -runai-model-streamer-gcs==0.15.0 +runai-model-streamer-gcs==0.15.3 # via runai-model-streamer -runai-model-streamer-s3==0.15.0 +runai-model-streamer-s3==0.15.3 # via runai-model-streamer s3transfer==0.10.3 # via boto3 diff --git a/setup.py b/setup.py index 5b7d12bb373e3..6fcb6653bc4a3 100644 --- a/setup.py +++ b/setup.py @@ -346,10 +346,13 @@ class precompiled_wheel_utils: The order of preference is: 1. user-specified wheel location (can be either local or remote, via VLLM_PRECOMPILED_WHEEL_LOCATION) - 2. user-specified variant from nightly repo (current main commit via - VLLM_PRECOMPILED_WHEEL_VARIANT) + 2. user-specified variant (VLLM_PRECOMPILED_WHEEL_VARIANT) from nightly repo 3. the variant corresponding to VLLM_MAIN_CUDA_VERSION from nightly repo - 4. the default variant from nightly repo (current main commit) + 4. the default variant from nightly repo + + If downloading from the nightly repo, the commit can be specified via + VLLM_PRECOMPILED_WHEEL_COMMIT; otherwise, the head commit in the main branch + is used. """ wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None) if wheel_location is not None: @@ -362,10 +365,13 @@ class precompiled_wheel_utils: # try to fetch the wheel metadata from the nightly wheel repo main_variant = "cu" + envs.VLLM_MAIN_CUDA_VERSION.replace(".", "") variant = os.getenv("VLLM_PRECOMPILED_WHEEL_VARIANT", main_variant) - commit = os.getenv( - "VLLM_PRECOMPILED_WHEEL_COMMIT", - precompiled_wheel_utils.get_base_commit_in_main_branch(), - ) + commit = os.getenv("VLLM_PRECOMPILED_WHEEL_COMMIT", "").lower() + if not commit or len(commit) != 40: + print( + f"VLLM_PRECOMPILED_WHEEL_COMMIT not valid: {commit}" + ", trying to fetch base commit in main branch" + ) + commit = precompiled_wheel_utils.get_base_commit_in_main_branch() print(f"Using precompiled wheel commit {commit} with variant {variant}") try_default = False wheels, repo_url, download_filename = None, None, None @@ -461,14 +467,22 @@ class precompiled_wheel_utils: "vllm/cumem_allocator.abi3.so", ] - compiled_regex = re.compile( + flash_attn_regex = re.compile( r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py" ) + triton_kernels_regex = re.compile( + r"vllm/third_party/triton_kernels/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py" + ) file_members = list( filter(lambda x: x.filename in files_to_copy, wheel.filelist) ) file_members += list( - filter(lambda x: compiled_regex.match(x.filename), wheel.filelist) + filter(lambda x: flash_attn_regex.match(x.filename), wheel.filelist) + ) + file_members += list( + filter( + lambda x: triton_kernels_regex.match(x.filename), wheel.filelist + ) ) for file in file_members: @@ -494,10 +508,6 @@ class precompiled_wheel_utils: @staticmethod def get_base_commit_in_main_branch() -> str: - # Force to use the nightly wheel. This is mainly used for CI testing. - if envs.VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: - return "nightly" - try: # Get the latest commit hash of the upstream main branch. resp_json = subprocess.check_output( @@ -508,6 +518,7 @@ class precompiled_wheel_utils: ] ).decode("utf-8") upstream_main_commit = json.loads(resp_json)["sha"] + print(f"Upstream main branch latest commit: {upstream_main_commit}") # In Docker build context, .git may be immutable or missing. if envs.VLLM_DOCKER_BUILD_CONTEXT: @@ -648,7 +659,7 @@ def get_vllm_version() -> str: if envs.VLLM_TARGET_DEVICE == "empty": version += f"{sep}empty" elif _is_cuda(): - if envs.VLLM_USE_PRECOMPILED: + if envs.VLLM_USE_PRECOMPILED and not envs.VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX: version += f"{sep}precompiled" else: cuda_version = str(get_nvcc_cuda_version()) @@ -786,7 +797,7 @@ setup( "bench": ["pandas", "matplotlib", "seaborn", "datasets"], "tensorizer": ["tensorizer==2.10.1"], "fastsafetensors": ["fastsafetensors >= 0.1.10"], - "runai": ["runai-model-streamer[s3,gcs] >= 0.15.0"], + "runai": ["runai-model-streamer[s3,gcs] >= 0.15.3"], "audio": [ "librosa", "soundfile", diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 9e912c6d810d2..8dd6959a01d09 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -392,39 +392,48 @@ def test_pass_config_deprecation(caplog_vllm): assert "enable_fusion is deprecated" in caplog_vllm.text assert config.fuse_norm_quant is True assert config.fuse_act_quant is True - assert config.enable_fusion is None + assert config.enable_fusion is True # Test enable_attn_fusion -> fuse_attn_quant caplog_vllm.clear() config = PassConfig(enable_attn_fusion=True) assert "enable_attn_fusion is deprecated" in caplog_vllm.text assert config.fuse_attn_quant is True - assert config.enable_attn_fusion is None + assert config.enable_attn_fusion is True # Test enable_noop -> eliminate_noops caplog_vllm.clear() config = PassConfig(enable_noop=True) assert "enable_noop is deprecated" in caplog_vllm.text assert config.eliminate_noops is True - assert config.enable_noop is None + assert config.enable_noop is True # Test enable_sequence_parallelism -> enable_sp caplog_vllm.clear() config = PassConfig(enable_sequence_parallelism=True) assert "enable_sequence_parallelism is deprecated" in caplog_vllm.text assert config.enable_sp is True - assert config.enable_sequence_parallelism is None + assert config.enable_sequence_parallelism is True # Test enable_async_tp -> fuse_gemm_comms caplog_vllm.clear() config = PassConfig(enable_async_tp=True) assert "enable_async_tp is deprecated" in caplog_vllm.text assert config.fuse_gemm_comms is True - assert config.enable_async_tp is None + assert config.enable_async_tp is True # Test enable_fi_allreduce_fusion -> fuse_allreduce_rms caplog_vllm.clear() config = PassConfig(enable_fi_allreduce_fusion=True) assert "enable_fi_allreduce_fusion is deprecated" in caplog_vllm.text assert config.fuse_allreduce_rms is True - assert config.enable_fi_allreduce_fusion is None + assert config.enable_fi_allreduce_fusion is True + + # Test hash consistency + config_old = PassConfig(enable_fusion=True) + config_new = PassConfig(fuse_norm_quant=True, fuse_act_quant=True) + assert config_old.compute_hash() == config_new.compute_hash() + + config_old = PassConfig(enable_async_tp=True) + config_new = PassConfig(fuse_gemm_comms=True) + assert config_old.compute_hash() == config_new.compute_hash() diff --git a/tests/distributed/test_eplb_spec_decode.py b/tests/distributed/test_eplb_spec_decode.py index 868cc702866e2..22977ce94404b 100644 --- a/tests/distributed/test_eplb_spec_decode.py +++ b/tests/distributed/test_eplb_spec_decode.py @@ -6,6 +6,7 @@ import lm_eval import pytest from tests.utils import large_gpu_mark +from vllm.platforms import current_platform def get_model_args( @@ -45,6 +46,12 @@ def get_model_args( return model_args +pytestmark = pytest.mark.skipif( + current_platform.is_rocm(), + reason="EPLB with Spec Decode is a work in progress on ROCm.", +) + + @pytest.mark.parametrize( "model_setup", [ diff --git a/tests/entrypoints/openai/test_messages.py b/tests/entrypoints/openai/test_messages.py index 3e390ad496428..b804a1a7a841a 100644 --- a/tests/entrypoints/openai/test_messages.py +++ b/tests/entrypoints/openai/test_messages.py @@ -69,9 +69,20 @@ async def test_anthropic_streaming(client: anthropic.AsyncAnthropic): stream=True, ) + first_chunk = None + chunk_count = 0 async for chunk in resp: + chunk_count += 1 + if first_chunk is None and chunk.type == "message_start": + first_chunk = chunk print(chunk.model_dump_json()) + assert chunk_count > 0 + assert first_chunk is not None, "message_start chunk was never observed" + assert first_chunk.usage is not None, "first chunk should include usage stats" + assert first_chunk.usage["output_tokens"] == 0 + assert first_chunk.usage["input_tokens"] > 5 + @pytest.mark.asyncio async def test_anthropic_tool_call(client: anthropic.AsyncAnthropic): diff --git a/tests/entrypoints/openai/test_vision_embeds.py b/tests/entrypoints/openai/test_vision_embeds.py index a6593c5b05e2e..42d9fe4840bbe 100644 --- a/tests/entrypoints/openai/test_vision_embeds.py +++ b/tests/entrypoints/openai/test_vision_embeds.py @@ -2,64 +2,47 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 -import io import numpy as np import pytest import requests import torch +from vllm.utils.serial_utils import tensor2base64 + from ...utils import RemoteOpenAIServer -MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11" -DTYPE = "float16" - -def _terratorch_dummy_inputs(model_name: str): +def _terratorch_dummy_messages(): pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16) location_coords = torch.full((1, 2), 1.0, dtype=torch.float16) - buffer_tiff = io.BytesIO() - torch.save(pixel_values, buffer_tiff) - buffer_tiff.seek(0) - binary_data = buffer_tiff.read() - base64_tensor_embedding = base64.b64encode(binary_data).decode("utf-8") - - buffer_coord = io.BytesIO() - torch.save(location_coords, buffer_coord) - buffer_coord.seek(0) - binary_data = buffer_coord.read() - base64_coord_embedding = base64.b64encode(binary_data).decode("utf-8") - - return { - "model": model_name, - "additional_data": {"prompt_token_ids": [1]}, - "encoding_format": "base64", - "messages": [ - { - "role": "user", - "content": [ - { - "type": "image_embeds", - "image_embeds": { - "pixel_values": base64_tensor_embedding, - "location_coords": base64_coord_embedding, - }, - } - ], - } - ], - } + return [ + { + "role": "user", + "content": [ + { + "type": "image_embeds", + "image_embeds": { + "pixel_values": tensor2base64(pixel_values), + "location_coords": tensor2base64(location_coords), + }, + } + ], + } + ] -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_single_request(model_name: str): +@pytest.mark.parametrize( + "model_name", ["ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"] +) +def test_single_request(model_name: str): args = [ "--runner", "pooling", # use half precision for speed and memory savings in CI environment "--dtype", - DTYPE, + "float16", "--enforce-eager", "--trust-remote-code", "--max-num-seqs", @@ -70,11 +53,15 @@ async def test_single_request(model_name: str): "--enable-mm-embeds", ] - with RemoteOpenAIServer(MODEL_NAME, args) as server: - prompt = _terratorch_dummy_inputs(model_name) - - # test single pooling - response = requests.post(server.url_for("pooling"), json=prompt) + with RemoteOpenAIServer(model_name, args) as server: + response = requests.post( + server.url_for("pooling"), + json={ + "model": model_name, + "messages": _terratorch_dummy_messages(), + "encoding_format": "base64", + }, + ) response.raise_for_status() output = response.json()["data"][0]["data"] diff --git a/tests/entrypoints/pooling/classify/test_offline.py b/tests/entrypoints/pooling/classify/test_offline.py index 1063c3b6b755c..a07fcd372721a 100644 --- a/tests/entrypoints/pooling/classify/test_offline.py +++ b/tests/entrypoints/pooling/classify/test_offline.py @@ -61,11 +61,8 @@ def test_pooling_params(llm: LLM): @pytest.mark.skip_global_cleanup -def test_encode_api(llm: LLM): - # chunked prefill does not support all pooling - err_msg = "pooling_task must be one of.+" - with pytest.raises(ValueError, match=err_msg): - llm.encode(prompts, pooling_task="token_classify", use_tqdm=False) +def test_token_classify(llm: LLM): + llm.encode(prompts, pooling_task="token_classify", use_tqdm=False) def test_score_api(llm: LLM): diff --git a/tests/entrypoints/pooling/classify/test_online.py b/tests/entrypoints/pooling/classify/test_online.py index 6fef688586955..1a6c33b455e65 100644 --- a/tests/entrypoints/pooling/classify/test_online.py +++ b/tests/entrypoints/pooling/classify/test_online.py @@ -255,21 +255,21 @@ async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str): - # token_classify uses ALL pooling, which does not support chunked prefill. task = "token_classify" + input_text = ["This product was excellent and exceeded my expectations"] response = requests.post( server.url_for("pooling"), json={ "model": model_name, - "input": "test", + "input": input_text, "encoding_format": "float", "task": task, }, ) - assert response.json()["error"]["type"] == "BadRequestError" - assert response.json()["error"]["message"].startswith( - f"Task {task} is not supported" - ) + poolings = PoolingResponse.model_validate(response.json()) + assert len(poolings.data) == 1 + assert len(poolings.data[0].data) == 8 + assert len(poolings.data[0].data[0]) == 2 @pytest.mark.asyncio diff --git a/tests/entrypoints/pooling/embed/test_offline.py b/tests/entrypoints/pooling/embed/test_offline.py index f5eab4c29ae18..12b47b1a08a8b 100644 --- a/tests/entrypoints/pooling/embed/test_offline.py +++ b/tests/entrypoints/pooling/embed/test_offline.py @@ -42,7 +42,7 @@ def llm(): @pytest.mark.skip_global_cleanup -def test_encode_api(llm: LLM): +def test_token_embed(llm: LLM): outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False) multi_vector = outputs[0].outputs.data assert multi_vector.shape == (11, 384) diff --git a/tests/entrypoints/pooling/reward/test_offline.py b/tests/entrypoints/pooling/reward/test_offline.py index 0255704cecd94..b061b55145155 100644 --- a/tests/entrypoints/pooling/reward/test_offline.py +++ b/tests/entrypoints/pooling/reward/test_offline.py @@ -36,6 +36,13 @@ def llm(): cleanup_dist_env_and_memory() +@pytest.mark.skip_global_cleanup +def test_config(llm: LLM): + vllm_config = llm.llm_engine.vllm_config + assert vllm_config.cache_config.enable_prefix_caching + assert vllm_config.scheduler_config.enable_chunked_prefill + + def test_pooling_params(llm: LLM): def get_outputs(use_activation): outputs = llm.reward( diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 03a0c058ea690..75be34820bcd7 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -29,6 +29,7 @@ from vllm.multimodal.utils import ( encode_video_base64, ) from vllm.tokenizers import MistralTokenizer, get_tokenizer +from vllm.utils.serial_utils import tensor2base64 from ..models.registry import HF_EXAMPLE_MODELS from ..utils import VLLM_PATH @@ -85,11 +86,6 @@ def phi3v_model_config_image_embeds(): ) -@pytest.fixture(scope="module") -def phi3v_tokenizer(): - return get_tokenizer(PHI3V_MODEL_ID) - - @pytest.fixture(scope="function") def qwen2_audio_model_config(): return ModelConfig( @@ -115,11 +111,6 @@ def audio_embeds_model_config(): ) -@pytest.fixture(scope="module") -def qwen2_audio_tokenizer(): - return get_tokenizer(QWEN2AUDIO_MODEL_ID) - - @pytest.fixture(scope="function") def qwen25omni_model_config_mm_interleaved(): return ModelConfig( @@ -134,11 +125,6 @@ def qwen25omni_model_config_mm_interleaved(): ) -@pytest.fixture(scope="module") -def qwen25omni_tokenizer(): - return get_tokenizer(QWEN25OMNI_MODEL_ID) - - @pytest.fixture(scope="function") def mistral_model_config(): return ModelConfig( @@ -150,11 +136,6 @@ def mistral_model_config(): ) -@pytest.fixture(scope="module") -def mistral_tokenizer(): - return get_tokenizer(MISTRAL_MODEL_ID) - - @pytest.fixture(scope="module") def image_url(): image = ImageAsset("cherry_blossom") @@ -239,7 +220,6 @@ def _assert_mm_data_inputs( def test_parse_chat_messages_single_image( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -253,7 +233,6 @@ def test_parse_chat_messages_single_image( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -266,7 +245,6 @@ def test_parse_chat_messages_single_image( def test_parse_chat_messages_single_image_with_uuid( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -287,7 +265,6 @@ def test_parse_chat_messages_single_image_with_uuid( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -300,7 +277,6 @@ def test_parse_chat_messages_single_image_with_uuid( def test_parse_chat_messages_single_empty_image_with_uuid( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -319,7 +295,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -332,7 +307,6 @@ def test_parse_chat_messages_single_empty_image_with_uuid( def test_parse_chat_messages_single_image_with_bad_uuid_format( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -354,7 +328,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -367,7 +340,6 @@ def test_parse_chat_messages_single_image_with_bad_uuid_format( def test_parse_chat_messages_multiple_images_with_uuids( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid1 = "my_uuid_1" @@ -397,7 +369,6 @@ def test_parse_chat_messages_multiple_images_with_uuids( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -413,7 +384,6 @@ def test_parse_chat_messages_multiple_images_with_uuids( def test_parse_chat_messages_multiple_empty_images_with_uuids( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid1 = "my_uuid_1" @@ -439,7 +409,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -455,7 +424,6 @@ def test_parse_chat_messages_multiple_empty_images_with_uuids( def test_parse_chat_messages_mixed_empty_images_with_uuids( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid1 = "my_uuid_1" @@ -483,7 +451,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -500,7 +467,6 @@ def test_parse_chat_messages_mixed_empty_images_with_uuids( @pytest.mark.asyncio async def test_parse_chat_messages_single_image_with_uuid_async( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -519,7 +485,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -533,7 +498,6 @@ async def test_parse_chat_messages_single_image_with_uuid_async( @pytest.mark.asyncio async def test_parse_chat_messages_empty_image_with_uuid_async( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -552,7 +516,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -566,7 +529,6 @@ async def test_parse_chat_messages_empty_image_with_uuid_async( @pytest.mark.asyncio async def test_parse_chat_messages_multiple_images_with_uuids_async( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid1 = "my_uuid_1" @@ -592,7 +554,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -609,7 +570,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_async( @pytest.mark.asyncio async def test_parse_chat_messages_multiple_empty_images_with_uuids_async( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid1 = "my_uuid_1" @@ -635,7 +595,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -652,7 +611,6 @@ async def test_parse_chat_messages_multiple_empty_images_with_uuids_async( @pytest.mark.asyncio async def test_parse_chat_messages_multiple_images_with_partial_uuids_async( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid2 = "my_uuid_2" @@ -676,7 +634,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -692,7 +649,6 @@ async def test_parse_chat_messages_multiple_images_with_partial_uuids_async( def test_parse_chat_messages_empty_system( mistral_model_config, - mistral_tokenizer, ): # Test string format conversation, _, _ = parse_chat_messages( @@ -704,7 +660,6 @@ def test_parse_chat_messages_empty_system( }, ], mistral_model_config, - mistral_tokenizer, content_format="string", ) assert conversation == [ @@ -722,7 +677,6 @@ def test_parse_chat_messages_empty_system( }, ], mistral_model_config, - mistral_tokenizer, content_format="openai", ) assert conversation == [ @@ -734,7 +688,6 @@ def test_parse_chat_messages_empty_system( @pytest.mark.asyncio async def test_parse_chat_messages_single_image_async( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_future, mm_uuids = parse_chat_messages_futures( @@ -748,7 +701,6 @@ async def test_parse_chat_messages_single_image_async( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -761,7 +713,6 @@ async def test_parse_chat_messages_single_image_async( def test_parse_chat_messages_multiple_images( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -779,7 +730,6 @@ def test_parse_chat_messages_multiple_images( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -795,7 +745,6 @@ def test_parse_chat_messages_multiple_images( def test_parse_chat_messages_empty_pil_image_with_uuid( phi3v_model_config, - phi3v_tokenizer, ): uuid = "abcd" conversation, mm_data, mm_uuids = parse_chat_messages( @@ -809,7 +758,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -825,7 +773,6 @@ def test_parse_chat_messages_empty_pil_image_with_uuid( def test_parse_chat_messages_empty_image_embeds_with_uuid( phi3v_model_config_image_embeds, - phi3v_tokenizer, ): uuid = "abcd" conversation, mm_data, mm_uuids = parse_chat_messages( @@ -839,7 +786,6 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid( } ], phi3v_model_config_image_embeds, - phi3v_tokenizer, content_format="string", ) @@ -857,7 +803,6 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid( def test_parse_chat_messages_empty_audio_embeds_with_uuid( audio_embeds_model_config, - qwen2_audio_tokenizer, ): """Test audio_embeds with UUID (no actual embeds data).""" uuid = "test-audio-uuid-123" @@ -873,7 +818,6 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid( } ], audio_embeds_model_config, - qwen2_audio_tokenizer, content_format="string", ) @@ -889,11 +833,8 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid( def test_parse_chat_messages_audio_embeds_with_string( audio_embeds_model_config, - qwen2_audio_tokenizer, ): """Test audio_embeds with base64 string embedding data.""" - import base64 - import io import torch @@ -901,11 +842,7 @@ def test_parse_chat_messages_audio_embeds_with_string( audio_embedding = torch.randn(1, 128, 768) # Encode it as base64 - buffer = io.BytesIO() - torch.save(audio_embedding, buffer) - buffer.seek(0) - binary_data = buffer.read() - base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8") + base64_audio_embedding = tensor2base64(audio_embedding) conversation, mm_data, mm_uuids = parse_chat_messages( [ @@ -921,7 +858,6 @@ def test_parse_chat_messages_audio_embeds_with_string( } ], audio_embeds_model_config, - qwen2_audio_tokenizer, content_format="string", ) @@ -939,11 +875,8 @@ def test_parse_chat_messages_audio_embeds_with_string( @pytest.mark.asyncio async def test_parse_chat_messages_audio_embeds_async( audio_embeds_model_config, - qwen2_audio_tokenizer, ): """Test audio_embeds with async futures.""" - import base64 - import io import torch @@ -951,11 +884,7 @@ async def test_parse_chat_messages_audio_embeds_async( audio_embedding = torch.randn(1, 128, 768) # Encode it as base64 - buffer = io.BytesIO() - torch.save(audio_embedding, buffer) - buffer.seek(0) - binary_data = buffer.read() - base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8") + base64_audio_embedding = tensor2base64(audio_embedding) conversation, mm_future, mm_uuids = parse_chat_messages_futures( [ @@ -971,7 +900,6 @@ async def test_parse_chat_messages_audio_embeds_async( } ], audio_embeds_model_config, - qwen2_audio_tokenizer, content_format="string", ) @@ -990,7 +918,6 @@ async def test_parse_chat_messages_audio_embeds_async( @pytest.mark.asyncio async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( phi3v_model_config_image_embeds, - phi3v_tokenizer, ): uuid = "abcd" conversation, mm_future, mm_uuids = parse_chat_messages_futures( @@ -1004,7 +931,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( } ], phi3v_model_config_image_embeds, - phi3v_tokenizer, content_format="string", ) @@ -1024,7 +950,6 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async( @pytest.mark.asyncio async def test_parse_chat_messages_multiple_images_async( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_future, mm_uuids = parse_chat_messages_futures( @@ -1042,7 +967,6 @@ async def test_parse_chat_messages_multiple_images_async( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -1058,7 +982,6 @@ async def test_parse_chat_messages_multiple_images_async( def test_parse_chat_messages_placeholder_already_in_prompt( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -1076,7 +999,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) assert conversation == [ @@ -1091,7 +1013,6 @@ def test_parse_chat_messages_placeholder_already_in_prompt( def test_parse_chat_messages_placeholder_one_already_in_prompt( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -1110,7 +1031,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -1127,7 +1047,6 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt( def test_parse_chat_messages_multiple_images_across_messages( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -1149,7 +1068,6 @@ def test_parse_chat_messages_multiple_images_across_messages( }, ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -1164,7 +1082,6 @@ def test_parse_chat_messages_multiple_images_across_messages( def test_parse_chat_messages_multiple_images_with_uuids_across_messages( phi3v_model_config, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -1195,7 +1112,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages( }, ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -1210,7 +1126,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_across_messages( def test_parse_chat_messages_context_text_format( phi3v_model_config, - phi3v_tokenizer, ): conversation, mm_data, mm_uuids = parse_chat_messages( [ @@ -1222,7 +1137,6 @@ def test_parse_chat_messages_context_text_format( {"role": "user", "content": "What about this one?"}, ], phi3v_model_config, - phi3v_tokenizer, content_format="openai", ) @@ -1246,7 +1160,6 @@ def test_parse_chat_messages_context_text_format( def test_parse_chat_messages_rejects_too_many_images_in_one_message( phi3v_model_config, - phi3v_tokenizer, image_url, ): with warnings.catch_warnings(): @@ -1277,14 +1190,12 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) def test_parse_chat_messages_rejects_too_many_images_across_messages( phi3v_model_config, - phi3v_tokenizer, image_url, ): with warnings.catch_warnings(): @@ -1322,14 +1233,12 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( }, ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) def test_parse_chat_messages_multiple_images_uncommon_input( phi3v_model_config, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -1344,7 +1253,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input( } ], phi3v_model_config, - phi3v_tokenizer, content_format="string", ) @@ -1360,7 +1268,6 @@ def test_parse_chat_messages_multiple_images_uncommon_input( def test_parse_chat_messages_multiple_images_interleave( phi3v_model_config_mm_interleaved, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -1380,7 +1287,6 @@ def test_parse_chat_messages_multiple_images_interleave( } ], phi3v_model_config_mm_interleaved, - phi3v_tokenizer, content_format="string", ) @@ -1398,7 +1304,6 @@ def test_parse_chat_messages_multiple_images_interleave( @pytest.mark.asyncio async def test_parse_chat_messages_multiple_images_interleave_async( phi3v_model_config_mm_interleaved, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages_futures( @@ -1418,7 +1323,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async( } ], phi3v_model_config_mm_interleaved, - phi3v_tokenizer, content_format="string", ) @@ -1436,7 +1340,6 @@ async def test_parse_chat_messages_multiple_images_interleave_async( @pytest.mark.asyncio async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async( phi3v_model_config_mm_interleaved, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -1465,7 +1368,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async( } ], phi3v_model_config_mm_interleaved, - phi3v_tokenizer, content_format="string", ) @@ -1482,7 +1384,6 @@ async def test_parse_chat_messages_multiple_images_with_uuids_interleave_async( def test_parse_chat_messages_multiple_images_multiple_messages_interleave( phi3v_model_config_mm_interleaved, - phi3v_tokenizer, image_url, ): conversation, mm_data, mm_uuids = parse_chat_messages( @@ -1505,7 +1406,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( }, ], phi3v_model_config_mm_interleaved, - phi3v_tokenizer, content_format="string", ) @@ -1523,7 +1423,6 @@ def test_parse_chat_messages_multiple_images_multiple_messages_interleave( def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interleave( phi3v_model_config_mm_interleaved, - phi3v_tokenizer, image_url, ): image_uuid = str(hash(image_url)) @@ -1555,7 +1454,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl }, ], phi3v_model_config_mm_interleaved, - phi3v_tokenizer, content_format="string", ) @@ -1573,7 +1471,6 @@ def test_parse_chat_messages_multiple_images_with_uuids_multiple_messages_interl def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, image_url, video_url, audio_url, @@ -1601,7 +1498,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( }, ], qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, content_format="string", ) @@ -1627,7 +1523,6 @@ def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interleave( qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, image_url, video_url, audio_url, @@ -1671,7 +1566,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl }, ], qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, content_format="string", ) @@ -1699,7 +1593,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_messages_interl def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_messages_interleave( # noqa: E501 qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, image_url, video_url, audio_url, @@ -1743,7 +1636,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes }, ], qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, content_format="string", ) @@ -1775,7 +1667,6 @@ def test_parse_chat_messages_multiple_modals_with_uuids_multiple_empty_media_mes def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_messages_interleave( # noqa: E501 qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, image_url, video_url, audio_url, @@ -1811,7 +1702,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message }, ], qwen25omni_model_config_mm_interleaved, - qwen25omni_tokenizer, content_format="string", ) @@ -1837,7 +1727,6 @@ def test_parse_chat_messages_multiple_modals_with_partial_uuids_multiple_message def test_parse_chat_messages_multiple_images_interleave_with_placeholders( phi3v_model_config_mm_interleaved, - phi3v_tokenizer, image_url, ): with pytest.raises( @@ -1861,7 +1750,6 @@ def test_parse_chat_messages_multiple_images_interleave_with_placeholders( } ], phi3v_model_config_mm_interleaved, - phi3v_tokenizer, content_format="string", ) @@ -2237,9 +2125,7 @@ def test_resolve_content_format_examples(template_path, expected_format): assert resolved_format == expected_format -def test_parse_chat_messages_include_thinking_chunk( - mistral_model_config, mistral_tokenizer -): +def test_parse_chat_messages_include_thinking_chunk(mistral_model_config): messages = [ { "role": "system", @@ -2269,7 +2155,6 @@ def test_parse_chat_messages_include_thinking_chunk( conversation_with_thinking, _, _ = parse_chat_messages( messages, mistral_model_config, - mistral_tokenizer, content_format="openai", ) @@ -2353,7 +2238,6 @@ def test_apply_mistral_chat_template_thinking_chunk(): def test_parse_chat_messages_single_empty_audio_with_uuid( qwen2_audio_model_config, - qwen2_audio_tokenizer, ): audio_uuid = "abcd" conversation, mm_data, mm_uuids = parse_chat_messages( @@ -2371,7 +2255,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid( } ], qwen2_audio_model_config, - qwen2_audio_tokenizer, content_format="string", ) @@ -2389,7 +2272,6 @@ def test_parse_chat_messages_single_empty_audio_with_uuid( @pytest.mark.asyncio async def test_parse_chat_messages_single_empty_audio_with_uuid_async( qwen2_audio_model_config, - qwen2_audio_tokenizer, ): audio_uuid = "abcd" conversation, mm_future, mm_uuids = parse_chat_messages_futures( @@ -2407,7 +2289,6 @@ async def test_parse_chat_messages_single_empty_audio_with_uuid_async( } ], qwen2_audio_model_config, - qwen2_audio_tokenizer, content_format="string", ) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index d79fdfbe07af3..99b168dc75548 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import ( from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts, ) -from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( - BatchedTritonOrDeepGemmExperts, -) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -286,16 +283,6 @@ if has_deep_gemm() and is_deep_gemm_supported(): needs_matching_quant=False, needs_deep_gemm=True, ) - register_experts( - BatchedTritonOrDeepGemmExperts, - batched_format, - common_float_and_int_types, - blocked_quantization_support=True, - supports_chunking=False, - supports_expert_map=False, - needs_matching_quant=True, - needs_deep_gemm=True, - ) register_experts( TritonOrDeepGemmExperts, standard_format, @@ -457,10 +444,6 @@ def make_fused_experts( kwargs = batch_kwargs | quant_kwargs print(f"Making BatchedTritonExperts {kwargs} ...") experts = BatchedTritonExperts(**kwargs) - elif fused_experts_type == BatchedTritonOrDeepGemmExperts: - kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs - print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") - experts = BatchedTritonOrDeepGemmExperts(**kwargs) elif fused_experts_type == DeepGemmExperts: print(f"Making DeepGemmExperts {quant_config} ...") experts = DeepGemmExperts(quant_config) diff --git a/tests/kernels/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py b/tests/kernels/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py new file mode 100644 index 0000000000000..e4617072cd52c --- /dev/null +++ b/tests/kernels/moe/test_silu_mul_per_token_group_quant_fp8_colmajor.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _per_token_group_quant_fp8_colmajor, + silu_mul_per_token_group_quant_fp8_colmajor, +) +from vllm.platforms import current_platform +from vllm.triton_utils import triton +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used + +FLOAT8_DTYPE = torch.float8_e4m3fn +GROUP_SIZE = 128 + + +def reference_quant(x: torch.Tensor, use_ue8m0: bool): + """ + Reference triton quant kernel from, + vllm.model_executor.layers.quantization.utils.fp8_utils + """ + + x_q = torch.empty_like(x, device=x.device, dtype=FLOAT8_DTYPE) + + # Allocate the scale tensor in column-major format. + shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + + M = x.numel() // GROUP_SIZE + N = GROUP_SIZE + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + + finfo = torch.finfo(FLOAT8_DTYPE) + fp8_min = finfo.min + fp8_max = finfo.max + + _per_token_group_quant_fp8_colmajor[(M,)]( + x, + x_q, + x_s, + GROUP_SIZE, + x.shape[1], + x.stride(0), + x_s.stride(1), + eps=1e-10, + fp8_min=fp8_min, + fp8_max=fp8_max, + use_ue8m0=use_ue8m0, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + return x_q, x_s + + +def reference(x: torch.Tensor, use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]: + T, N = x.size() + ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda") + torch.ops._C.silu_and_mul(ref_act_out, x) + return reference_quant(ref_act_out, use_ue8m0) + + +@pytest.mark.parametrize("T", [128, 256, 512]) +@pytest.mark.parametrize("N", [128 * 2, 256 * 2, 768 * 2, 2048 * 2, 7168 * 2]) +def test_silu_mul_fp8_quant_deep_gemm(T: int, N: int): + current_platform.seed_everything(42) + + input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda") + + use_ue8m0 = is_deep_gemm_e8m0_used() + + # Test + output, output_scales = silu_mul_per_token_group_quant_fp8_colmajor( + input, use_ue8m0=use_ue8m0 + ) + + # Reference + ref_output, ref_output_scales = reference(input, use_ue8m0) + + torch.testing.assert_close(output.to(torch.float32), ref_output.to(torch.float32)) + torch.testing.assert_close(output_scales, ref_output_scales) diff --git a/tests/model_executor/model_loader/runai_model_streamer/__init__.py b/tests/model_executor/model_loader/runai_streamer_loader/__init__.py similarity index 100% rename from tests/model_executor/model_loader/runai_model_streamer/__init__.py rename to tests/model_executor/model_loader/runai_streamer_loader/__init__.py diff --git a/tests/model_executor/model_loader/runai_streamer_loader/conftest.py b/tests/model_executor/model_loader/runai_streamer_loader/conftest.py new file mode 100644 index 0000000000000..9a022f6bbd9d1 --- /dev/null +++ b/tests/model_executor/model_loader/runai_streamer_loader/conftest.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port +from vllm.v1.executor import UniProcExecutor +from vllm.v1.worker.worker_base import WorkerWrapperBase + + +# This is a dummy executor for patching in test_runai_model_streamer_s3.py. +# We cannot use vllm_runner fixture here, because it spawns worker process. +# The worker process reimports the patched entities, and the patch is not applied. +class RunaiDummyExecutor(UniProcExecutor): + def _init_executor(self) -> None: + distributed_init_method = get_distributed_init_method(get_ip(), get_open_port()) + + local_rank = 0 + rank = 0 + is_driver_worker = True + + device_info = self.vllm_config.device_config.device.__str__().split(":") + if len(device_info) > 1: + local_rank = int(device_info[1]) + + worker_rpc_kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker, + ) + + wrapper_kwargs = { + "vllm_config": self.vllm_config, + } + + self.driver_worker = WorkerWrapperBase(**wrapper_kwargs) + + self.collective_rpc("init_worker", args=([worker_rpc_kwargs],)) + self.collective_rpc("init_device") diff --git a/tests/model_executor/model_loader/runai_model_streamer/test_runai_model_streamer_loader.py b/tests/model_executor/model_loader/runai_streamer_loader/test_runai_model_streamer_loader.py similarity index 100% rename from tests/model_executor/model_loader/runai_model_streamer/test_runai_model_streamer_loader.py rename to tests/model_executor/model_loader/runai_streamer_loader/test_runai_model_streamer_loader.py diff --git a/tests/model_executor/model_loader/runai_streamer_loader/test_runai_model_streamer_s3.py b/tests/model_executor/model_loader/runai_streamer_loader/test_runai_model_streamer_s3.py new file mode 100644 index 0000000000000..d60c9ba64cbdb --- /dev/null +++ b/tests/model_executor/model_loader/runai_streamer_loader/test_runai_model_streamer_s3.py @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from pathlib import Path + +from huggingface_hub import snapshot_download +from runai_model_streamer.safetensors_streamer.streamer_mock import StreamerPatcher + +from vllm.engine.arg_utils import EngineArgs + +from .conftest import RunaiDummyExecutor + +load_format = "runai_streamer" +test_model = "openai-community/gpt2" + + +def test_runai_model_loader_download_files_s3_mocked_with_patch( + vllm_runner, + tmp_path: Path, + monkeypatch, +): + patcher = StreamerPatcher(str(tmp_path)) + + test_mock_s3_model = "s3://my-mock-bucket/gpt2/" + + # Download model from HF + mock_model_dir = f"{tmp_path}/gpt2" + snapshot_download(repo_id=test_model, local_dir=mock_model_dir) + + monkeypatch.setattr( + "vllm.transformers_utils.runai_utils.runai_list_safetensors", + patcher.shim_list_safetensors, + ) + monkeypatch.setattr( + "vllm.transformers_utils.runai_utils.runai_pull_files", + patcher.shim_pull_files, + ) + monkeypatch.setattr( + "vllm.model_executor.model_loader.weight_utils.SafetensorsStreamer", + patcher.create_mock_streamer, + ) + + engine_args = EngineArgs( + model=test_mock_s3_model, + load_format=load_format, + tensor_parallel_size=1, + ) + + vllm_config = engine_args.create_engine_config() + + executor = RunaiDummyExecutor(vllm_config) + executor.driver_worker.load_model() diff --git a/tests/model_executor/model_loader/runai_model_streamer/test_runai_utils.py b/tests/model_executor/model_loader/runai_streamer_loader/test_runai_utils.py similarity index 100% rename from tests/model_executor/model_loader/runai_model_streamer/test_runai_utils.py rename to tests/model_executor/model_loader/runai_streamer_loader/test_runai_utils.py diff --git a/tests/model_executor/model_loader/runai_model_streamer/test_weight_utils.py b/tests/model_executor/model_loader/runai_streamer_loader/test_weight_utils.py similarity index 100% rename from tests/model_executor/model_loader/runai_model_streamer/test_weight_utils.py rename to tests/model_executor/model_loader/runai_streamer_loader/test_weight_utils.py diff --git a/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py new file mode 100644 index 0000000000000..c259c532220b2 --- /dev/null +++ b/tests/models/language/pooling/test_all_pooling_plus_chunked_prefill.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import torch +from transformers import AutoModel + +from tests.models.utils import check_embeddings_close +from vllm import TokensPrompt + + +@pytest.mark.parametrize( + "model", + ["Qwen/Qwen3-Embedding-0.6B"], +) +@torch.inference_mode +def test_embed_models(hf_runner, vllm_runner, model: str): + chunk_size = 10 + n_prompt_tokens = [55, 56, 57] + token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens] + + with vllm_runner( + model, + runner="pooling", + max_model_len=128, + max_num_batched_tokens=chunk_size, + enforce_eager=True, + # `enable_chunked_prefill`: Set to `False` instead of `None` in VllmRunner + enable_chunked_prefill=True, + enable_prefix_caching=True, + ) as vllm_model: + vllm_outputs = vllm_model.token_embed( + [TokensPrompt(prompt_token_ids=t) for t in token_prompts], + ) + + with hf_runner( + model, + auto_cls=AutoModel, + ) as hf_model: + hf_outputs = [] + for token_prompt in token_prompts: + inputs = hf_model.wrap_device({"input_ids": torch.tensor([token_prompt])}) + input_ids = inputs["input_ids"] + output = hf_model.model(input_ids) + hf_outputs.append(output.last_hidden_state.cpu().float()[0]) + + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + check_embeddings_close( + embeddings_0_lst=hf_output, + embeddings_1_lst=vllm_output, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/tests/models/language/pooling/test_extract_hidden_states.py b/tests/models/language/pooling/test_extract_hidden_states.py index 0d41b93233d5a..488b27e2da0f1 100644 --- a/tests/models/language/pooling/test_extract_hidden_states.py +++ b/tests/models/language/pooling/test_extract_hidden_states.py @@ -20,7 +20,6 @@ def test_extract_hidden_states(hf_runner, vllm_runner, model: str): max_model_len=128, enforce_eager=True, runner="pooling", - enable_chunked_prefill=False, enable_prefix_caching=True, ) as vllm_model: pooling_outputs = vllm_model.llm.encode( diff --git a/tests/models/multimodal/generation/conftest.py b/tests/models/multimodal/generation/conftest.py index ee3ecdb10fdb8..26f8586742cea 100644 --- a/tests/models/multimodal/generation/conftest.py +++ b/tests/models/multimodal/generation/conftest.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Pytest configuration for vLLM tests.""" +import warnings + import torch from vllm.platforms import current_platform @@ -14,6 +16,20 @@ def pytest_configure(config): if not current_platform.is_rocm(): return + skip_patterns = ["test_granite_speech.py"] + if any(pattern in str(arg) for arg in config.args for pattern in skip_patterns): + # Skip disabling SDP for Granite Speech tests on ROCm + return + + # Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers + # accuracy issues + # TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace torch.backends.cuda.enable_flash_sdp(False) torch.backends.cuda.enable_mem_efficient_sdp(False) torch.backends.cuda.enable_math_sdp(True) + warnings.warn( + "ROCm: Disabled flash_sdp and mem_efficient_sdp, enabled math_sdp " + "to avoid HuggingFace Transformers accuracy issues", + UserWarning, + stacklevel=1, + ) diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 0eaf7198f91b7..f896126a49089 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -403,12 +403,13 @@ VLM_TEST_SETTINGS = { # So, we need to reduce the number of tokens for the test to pass. max_tokens=8, num_logprobs=10, + auto_cls=AutoModelForCausalLM, marks=[large_gpu_mark(min_gb=32)], ), "glm4_1v": VLMTestInfo( models=["zai-org/GLM-4.1V-9B-Thinking"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|user|>\n{img_prompt}<|assistant|>", + prompt_formatter=lambda img_prompt: f"[gMASK]<|user|>\n{img_prompt}<|assistant|>\n", # noqa: E501 img_idx_to_prompt=lambda idx: "<|begin_of_image|><|image|><|end_of_image|>", video_idx_to_prompt=lambda idx: "<|begin_of_video|><|video|><|end_of_video|>", max_model_len=2048, @@ -423,6 +424,7 @@ VLM_TEST_SETTINGS = { models=["zai-org/GLM-4.1V-9B-Thinking"], # GLM4.1V require include video metadata for input test_type=VLMTestType.CUSTOM_INPUTS, + prompt_formatter=lambda vid_prompt: f"[gMASK]<|user|>\n{vid_prompt}<|assistant|>\n", # noqa: E501 max_model_len=4096, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, @@ -737,7 +739,13 @@ VLM_TEST_SETTINGS = { max_model_len=8192, max_num_seqs=2, auto_cls=AutoModelForImageTextToText, - marks=[large_gpu_mark(min_gb=48)], + marks=[ + large_gpu_mark(min_gb=48), + pytest.mark.skipif( + current_platform.is_rocm(), + reason="Model produces a vector of output in HF on ROCm", + ), + ], ), "qwen_vl": VLMTestInfo( models=["Qwen/Qwen-VL"], diff --git a/tests/models/multimodal/generation/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py index e39dfc888779e..f528a993f8551 100644 --- a/tests/models/multimodal/generation/test_granite_speech.py +++ b/tests/models/multimodal/generation/test_granite_speech.py @@ -8,6 +8,7 @@ from transformers import AutoModelForSpeechSeq2Seq from vllm.logprobs import SampleLogprobs from vllm.lora.request import LoRARequest +from vllm.platforms import current_platform from ....conftest import AudioTestAssets, HfRunner, PromptAudioInput, VllmRunner from ...registry import HF_EXAMPLE_MODELS @@ -34,6 +35,12 @@ audio_lora_path = MODEL_NAME models = [MODEL_NAME] +@pytest.fixture(autouse=True) +def set_attention_backend_for_rocm(monkeypatch): + if current_platform.is_rocm(): + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN") + + def run_test( hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], @@ -111,8 +118,12 @@ def run_test( @pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_model_len", [2048]) +@pytest.mark.parametrize( + "dtype", ["float16"] if current_platform.is_rocm() else ["bfloat16"] +) +@pytest.mark.parametrize( + "max_model_len", [512] if current_platform.is_rocm() else [2048] +) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) def test_models( diff --git a/tests/models/multimodal/generation/test_phi4_multimodal.py b/tests/models/multimodal/generation/test_phi4_multimodal.py deleted file mode 100644 index 62456221711ed..0000000000000 --- a/tests/models/multimodal/generation/test_phi4_multimodal.py +++ /dev/null @@ -1,281 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -from collections.abc import Sequence - -import librosa -import pytest -from huggingface_hub import snapshot_download - -from vllm.assets.image import ImageAsset -from vllm.lora.request import LoRARequest -from vllm.multimodal.image import rescale_image_size - -from ....conftest import ( - IMAGE_ASSETS, - HfRunner, - PromptAudioInput, - PromptImageInput, - VllmRunner, -) -from ....utils import large_gpu_test -from ...utils import check_logprobs_close - -HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts( - { - "stop_sign": "<|user|>\n<|image|>\nWhat's the content of the image?<|end|>\n<|assistant|>\n", # noqa: E501 - "cherry_blossom": "<|user|>\n<|image|>\nPlease infer the season with reason in details.<|end|>\n<|assistant|>\n", # noqa: E501 - } -) -HF_MULTIIMAGE_IMAGE_PROMPT = ( - "<|user|>\n<|image|>\n<|image|>\nDescribe these images.<|end|>\n<|assistant|>\n" # noqa: E501 -) - -model_path = snapshot_download( - "microsoft/Phi-4-multimodal-instruct", revision="refs/pr/70" -) -# Since the vision-lora and speech-lora co-exist with the base model, -# we have to manually specify the path of the lora weights. -vision_lora_path = os.path.join(model_path, "vision-lora") -speech_question = os.path.join( - model_path, "examples", "what_is_shown_in_this_image.wav" -) -models = [model_path] - -target_dtype = "half" - - -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - inputs: Sequence[tuple[list[str], PromptImageInput, PromptAudioInput | None]], - model: str, - *, - max_model_len: int, - dtype: str, - max_tokens: int, - num_logprobs: int, - mm_limit: int, - tensor_parallel_size: int, - distributed_executor_backend: str | None = None, -): - """Inference result should be the same between hf and vllm. - - All the image fixtures for the test are from IMAGE_ASSETS. - For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects - and corresponding MultiModalConfig as input. - Note, the text input is also adjusted to abide by vllm contract. - The text output is sanitized to be able to compare with hf. - """ - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). - # max_model_len should be greater than image_feature_size - with vllm_runner( - model, - task="generate", - max_model_len=max_model_len, - max_num_seqs=2, - dtype=dtype, - limit_mm_per_prompt={"image": mm_limit}, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enable_lora=True, - max_lora_rank=320, - gpu_memory_utilization=0.8, # set to 0.8 to avoid OOM in CI - enforce_eager=True, - trust_remote_code=False, - ) as vllm_model: - lora_request = LoRARequest("vision", 1, vision_lora_path) - vllm_outputs_per_case = [ - vllm_model.generate_greedy_logprobs( - prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - audios=audios, - lora_request=lora_request, - ) - for prompts, images, audios in inputs - ] - - with hf_runner(model, dtype=dtype) as hf_model: - hf_model.model.load_adapter( - vision_lora_path, - adapter_name="vision", - ) - hf_processor = hf_model.processor - eos_token_id = hf_processor.tokenizer.eos_token_id - hf_outputs_per_case = [ - hf_model.generate_greedy_logprobs_limit( - prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - audios=audios, - eos_token_id=eos_token_id, - ) - for prompts, images, audios in inputs - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, vllm_outputs_per_case): - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize( - "size_factors", - [ - # No image - [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.25, 0.5, 1.0], - ], -) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_model_len", [12800]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [10]) -def test_models( - hf_runner, - vllm_runner, - image_assets, - model, - size_factors, - dtype: str, - max_model_len: int, - max_tokens: int, - num_logprobs: int, -) -> None: - images = [asset.pil_image for asset in image_assets] - - inputs_per_image = [ - ( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - None, - ) - for image, prompt in zip(images, HF_IMAGE_PROMPTS) - ] - - run_test( - hf_runner, - vllm_runner, - inputs_per_image, - model, - dtype=dtype, - max_model_len=max_model_len, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - mm_limit=1, - tensor_parallel_size=1, - ) - - -@large_gpu_test(min_gb=48) -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize( - "size_factors", - [ - # No image - # [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.25, 0.5, 1.0], - ], -) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_model_len", [25600]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [10]) -def test_multi_images_models( - hf_runner, - vllm_runner, - image_assets, - model, - size_factors, - dtype: str, - max_model_len: int, - max_tokens: int, - num_logprobs: int, -) -> None: - images = [asset.pil_image for asset in image_assets] - - inputs_per_case = [ - ( - [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], - [ - [rescale_image_size(image, factor) for image in images] - for factor in size_factors - ], - None, - ), - ] - - run_test( - hf_runner, - vllm_runner, - inputs_per_case, - model, - dtype=dtype, - max_model_len=max_model_len, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - mm_limit=2, - tensor_parallel_size=1, - ) - - -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_model_len", [12800]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [10]) -def test_vision_speech_models( - hf_runner, - vllm_runner, - model, - dtype: str, - max_model_len: int, - max_tokens: int, - num_logprobs: int, -) -> None: - # use the example speech question so that the model outputs are reasonable - audio = librosa.load(speech_question, sr=16000) - image = ImageAsset("cherry_blossom").pil_image.convert("RGB") - - inputs_vision_speech = [ - ( - ["<|user|><|image|><|audio|><|end|><|assistant|>"], - [image], - [audio], - ), - ] - - run_test( - hf_runner, - vllm_runner, - inputs_vision_speech, - model, - dtype=dtype, - max_model_len=max_model_len, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - mm_limit=1, - tensor_parallel_size=1, - ) diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py index 3cad2c43d5623..375099f4365ac 100644 --- a/tests/models/multimodal/generation/test_pixtral.py +++ b/tests/models/multimodal/generation/test_pixtral.py @@ -15,6 +15,7 @@ from transformers import AutoProcessor from vllm import SamplingParams, TextPrompt, TokensPrompt from vllm.logprobs import Logprob, SampleLogprobs from vllm.multimodal import MultiModalDataBuiltins +from vllm.platforms import current_platform from ....utils import VLLM_PATH, large_gpu_test from ...utils import check_logprobs_close @@ -165,6 +166,15 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs: def test_chat( vllm_runner, max_model_len: int, model: str, dtype: str, local_asset_server ) -> None: + if ( + model == MISTRAL_SMALL_3_1_ID + and max_model_len == 65536 + and current_platform.is_rocm() + ): + pytest.skip( + "OOM on ROCm: 24B model with 65536 context length exceeds GPU memory" + ) + EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT[model]) with vllm_runner( model, diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py index 8c9c390911bdc..84109233685bb 100644 --- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py +++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py @@ -140,7 +140,7 @@ def video_with_metadata_glm4_1v(): metadata = VIDEO_ASSETS[0].metadata question = "Describe the video." video_prompt = "<|begin_of_video|><|video|><|end_of_video|>" - formatted_prompt = f"<|user|>\n{video_prompt}{question}<|assistant|>\n" + formatted_prompt = f"[gMASK]<|user|>\n{video_prompt}{question}<|assistant|>\n" scales = [0.1, 0.2, 0.25] video_input = [ diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index 87cd5c3cd3554..b2c62fbd119cc 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -25,6 +25,7 @@ from transformers import ( from transformers.video_utils import VideoMetadata from vllm.logprobs import SampleLogprobs +from vllm.platforms import current_platform from vllm.utils.collection_utils import is_list_of from .....conftest import HfRunner, ImageAsset, ImageTestAssets @@ -366,6 +367,40 @@ def gemma3_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOut def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: """Patches and returns an instance of the HfRunner to use for GLM4V.""" + if current_platform.is_rocm(): + import types + + config = hf_model.model.config + if hasattr(config, "num_layers") and not hasattr(config, "num_hidden_layers"): + config.num_hidden_layers = config.num_layers + config.output_hidden_states = True + + def patched_prepare_cache( + self, generation_config, model_kwargs, *args, **kwargs + ): + model_kwargs["past_key_values"] = None + model_kwargs["use_cache"] = False + return model_kwargs + + hf_model.model._prepare_cache_for_generation = types.MethodType( + patched_prepare_cache, hf_model.model + ) + original_generate = hf_model.model.generate + + def patched_generate(*args, **kwargs): + kwargs["output_hidden_states"] = True + kwargs["return_dict_in_generate"] = True + return original_generate(*args, **kwargs) + + hf_model.model.generate = patched_generate + original_forward = hf_model.model.forward + + def patched_forward(*args, **kwargs): + kwargs["output_hidden_states"] = True + return original_forward(*args, **kwargs) + + hf_model.model.forward = patched_forward + hf_processor = hf_model.processor def processor(*args, text="", images=None, **kwargs): @@ -406,7 +441,15 @@ def glm4_1v_patch_hf_runner(hf_model: HfRunner) -> HfRunner: if videos is not None and is_list_of(videos, tuple): # If videos is a list of tuples, we assume each tuple contains # (video_array, metadata) as in the case of GLM4.1V. - video_metadata = [[VideoMetadata(**video[1])] for video in videos] + # Filter out 'do_sample_frames' as it's not a valid VideoMetadata arg + video_metadata = [ + [ + VideoMetadata( + **{k: v for k, v in video[1].items() if k != "do_sample_frames"} + ) + ] + for video in videos + ] videos = [[video[0]] for video in videos] else: video_metadata = None diff --git a/tests/models/multimodal/pooling/conftest.py b/tests/models/multimodal/pooling/conftest.py new file mode 100644 index 0000000000000..c5f40cb42ca2a --- /dev/null +++ b/tests/models/multimodal/pooling/conftest.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Pytest configuration for vLLM pooling tests.""" + +import os +import warnings + +from vllm.platforms import current_platform + + +def pytest_collection_modifyitems(config, items): + """Set FLEX_ATTENTION backend for SigLIP tests on ROCm.""" + if not current_platform.is_rocm(): + return + + siglip_tests = [item for item in items if "test_siglip" in item.nodeid] + + if siglip_tests: + os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION" + warnings.warn( + "ROCm: Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION for SigLIP tests", + UserWarning, + stacklevel=1, + ) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 8ef1fba8df3e3..6b9d388f2b9b4 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -396,28 +396,6 @@ def test_processing_correctness( ) -# Phi4MultimodalForCausalLM share same model repo with original format -# Phi4MMForCausalLM, so we add it as a separate test case -# Remove this test after conversion PR merged: -# https://huggingface.co/microsoft/Phi-4-multimodal-instruct/discussions/70 -@pytest.mark.parametrize("model_arch", ["Phi4MultimodalForCausalLM"]) -@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0]) -@pytest.mark.parametrize("num_batches", [32]) -@pytest.mark.parametrize("simplify_rate", [1.0]) -def test_processing_correctness_phi4_multimodal( - model_arch: str, - hit_rate: float, - num_batches: int, - simplify_rate: float, -): - _test_processing_correctness( - model_arch, - hit_rate=hit_rate, - num_batches=num_batches, - simplify_rate=simplify_rate, - ) - - def _assert_inputs_equal( a: MultiModalInputs, b: MultiModalInputs, diff --git a/tests/models/registry.py b/tests/models/registry.py index 6b1d24b1c99b5..b9f9945eb5fb8 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -667,6 +667,10 @@ _MULTIMODAL_EXAMPLE_MODELS = { "moonshotai/Kimi-VL-A3B-Instruct", extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, trust_remote_code=True, + max_transformers_version="4.53.3", + transformers_version_reason="HF model uses deprecated transformers API " + "(PytorchGELUTanh, DynamicCache.seen_tokens, and more). See: " + "https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/discussions/31", ), "LightOnOCRForConditionalGeneration": _HfExamplesInfo( "lightonai/LightOnOCR-1B", @@ -767,10 +771,6 @@ _MULTIMODAL_EXAMPLE_MODELS = { "Phi4MMForCausalLM": _HfExamplesInfo( "microsoft/Phi-4-multimodal-instruct", trust_remote_code=True ), - "Phi4MultimodalForCausalLM": _HfExamplesInfo( - "microsoft/Phi-4-multimodal-instruct", - revision="refs/pr/70", - ), "PixtralForConditionalGeneration": _HfExamplesInfo( "mistralai/Pixtral-12B-2409", extras={ diff --git a/tests/reasoning/test_base_thinking_reasoning_parser.py b/tests/reasoning/test_base_thinking_reasoning_parser.py index d31b1c7d169b7..34e9483de54b3 100644 --- a/tests/reasoning/test_base_thinking_reasoning_parser.py +++ b/tests/reasoning/test_base_thinking_reasoning_parser.py @@ -112,7 +112,7 @@ class TestBaseThinkingReasoningParserMethods: """Test the is_reasoning_end method.""" parser = TestThinkingReasoningParser(test_tokenizer) end_token_id = parser.end_token_id - + start_token_id = parser.start_token_id # Test with end token present assert parser.is_reasoning_end([1, 2, end_token_id, 4]) is True @@ -122,6 +122,16 @@ class TestBaseThinkingReasoningParserMethods: # Test with empty list assert parser.is_reasoning_end([]) is False + # Test with interleaved thinking + assert parser.is_reasoning_end([1, start_token_id, 2, end_token_id]) is True + assert parser.is_reasoning_end([1, start_token_id, 2, 3]) is False + assert ( + parser.is_reasoning_end( + [1, start_token_id, 2, end_token_id, 2, 2, start_token_id] + ) + is False + ) + def test_extract_content_ids(self, test_tokenizer): """Test the extract_content_ids method.""" parser = TestThinkingReasoningParser(test_tokenizer) diff --git a/tests/standalone_tests/python_only_compile.sh b/tests/standalone_tests/python_only_compile.sh index 7cc5ef6596490..d29b9afcc6fbf 100644 --- a/tests/standalone_tests/python_only_compile.sh +++ b/tests/standalone_tests/python_only_compile.sh @@ -5,6 +5,10 @@ set -e set -x +merge_base_commit=$(git merge-base HEAD origin/main) +echo "Current merge base commit with main: $merge_base_commit" +git show --oneline -s $merge_base_commit + cd /vllm-workspace/ # uninstall vllm @@ -18,7 +22,7 @@ apt autoremove -y echo 'import os; os.system("touch /tmp/changed.file")' >> vllm/__init__.py -VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL=1 VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e . +VLLM_PRECOMPILED_WHEEL_COMMIT=$merge_base_commit VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e . # Run the script python3 -c 'import vllm' diff --git a/tests/test_config.py b/tests/test_config.py index 019c0d6d8733f..203447cd531fb 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -629,8 +629,8 @@ def test_s3_url_different_models_create_different_directories(mock_pull_files): ( "internlm/internlm2-1_8b-reward", "decoder", - False, - "Pooling models with all pooling does not support chunked prefill.", + True, + "Pooling models with causal attn and all pooling support chunked prefill.", ), ( "BAAI/bge-base-en", @@ -748,8 +748,8 @@ def test_is_chunked_prefill_supported( ( "internlm/internlm2-1_8b-reward", "decoder", - False, - "Pooling models with all pooling does not support prefix caching.", + True, + "Pooling models with causal attn and all pooling support prefix caching.", ), ( "BAAI/bge-base-en", diff --git a/tests/test_envs.py b/tests/test_envs.py index 6a9835a68e7e2..11bbec38202bf 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -365,3 +365,54 @@ class TestEnvSetWithChoices: with patch.dict(os.environ, {"TEST_ENV": "option1,option1,option2"}): env_func = env_set_with_choices("TEST_ENV", [], ["option1", "option2"]) assert env_func() == {"option1", "option2"} + + +class TestVllmConfigureLogging: + """Test cases for VLLM_CONFIGURE_LOGGING environment variable.""" + + def test_configure_logging_defaults_to_true(self): + """Test that VLLM_CONFIGURE_LOGGING defaults to True when not set.""" + # Ensure the env var is not set + with patch.dict(os.environ, {}, clear=False): + if "VLLM_CONFIGURE_LOGGING" in os.environ: + del os.environ["VLLM_CONFIGURE_LOGGING"] + + # Clear cache if it exists + if hasattr(envs.__getattr__, "cache_clear"): + envs.__getattr__.cache_clear() + + result = envs.VLLM_CONFIGURE_LOGGING + assert result is True + assert isinstance(result, bool) + + def test_configure_logging_with_zero_string(self): + """Test that VLLM_CONFIGURE_LOGGING='0' evaluates to False.""" + with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "0"}): + # Clear cache if it exists + if hasattr(envs.__getattr__, "cache_clear"): + envs.__getattr__.cache_clear() + + result = envs.VLLM_CONFIGURE_LOGGING + assert result is False + assert isinstance(result, bool) + + def test_configure_logging_with_one_string(self): + """Test that VLLM_CONFIGURE_LOGGING='1' evaluates to True.""" + with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "1"}): + # Clear cache if it exists + if hasattr(envs.__getattr__, "cache_clear"): + envs.__getattr__.cache_clear() + + result = envs.VLLM_CONFIGURE_LOGGING + assert result is True + assert isinstance(result, bool) + + def test_configure_logging_with_invalid_value_raises_error(self): + """Test that invalid VLLM_CONFIGURE_LOGGING value raises ValueError.""" + with patch.dict(os.environ, {"VLLM_CONFIGURE_LOGGING": "invalid"}): + # Clear cache if it exists + if hasattr(envs.__getattr__, "cache_clear"): + envs.__getattr__.cache_clear() + + with pytest.raises(ValueError, match="invalid literal for int"): + _ = envs.VLLM_CONFIGURE_LOGGING diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_use/test_mistral_tool_parser.py new file mode 100644 index 0000000000000..e5deb7f40eb35 --- /dev/null +++ b/tests/tool_use/test_mistral_tool_parser.py @@ -0,0 +1,847 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Generator + +import partial_json_parser +import pytest +from mistral_common.protocol.instruct.messages import AssistantMessage +from mistral_common.protocol.instruct.request import InstructRequest +from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall +from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolParser +from vllm.tokenizers import ( + MistralTokenizer, + TokenizerLike, + get_tokenizer, +) +from vllm.tokenizers.detokenizer_utils import detokenize_incrementally + + +@pytest.fixture(scope="module") +def mistral_pre_v11_tokenizer(): + MODEL = "mistralai/Mistral-7B-Instruct-v0.3" + return get_tokenizer(tokenizer_name=MODEL) + + +@pytest.fixture(scope="module") +def mistral_tokenizer(): + MODEL = "mistralai/Mistral-Small-3.2-24B-Instruct-2506" + return get_tokenizer(tokenizer_name=MODEL, tokenizer_mode="mistral") + + +@pytest.fixture +def mistral_pre_v11_tool_parser(mistral_pre_v11_tokenizer): + return MistralToolParser(mistral_pre_v11_tokenizer) + + +@pytest.fixture +def mistral_tool_parser(mistral_tokenizer): + return MistralToolParser(mistral_tokenizer) + + +def assert_tool_calls( + actual_tool_calls: list[ToolCall] | list[DeltaToolCall], + expected_tool_calls: list[ToolCall], +): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip( + actual_tool_calls, expected_tool_calls + ): + assert isinstance(actual_tool_call.id, str) + assert len(actual_tool_call.id) == 9 + + if isinstance(actual_tool_call, ToolCall): + assert actual_tool_call.type == "function" + elif isinstance(actual_tool_call, DeltaToolCall): + assert actual_tool_call.function is not None + assert actual_tool_call.function.name is not None + assert actual_tool_call.function.arguments is not None + assert actual_tool_call.function is not None + assert actual_tool_call.function.name == expected_tool_call.function.name, ( + f"got wrong function name:${actual_tool_call.function.name}" + ) + assert ( + actual_tool_call.function.arguments == expected_tool_call.function.arguments + ), f"got wrong function argument:${actual_tool_call.function.arguments}" + + +def fix_tool_call_tokenization( + tokens: list[int], + mistral_tool_parser: MistralToolParser, + mistral_tokenizer: TokenizerLike, +): + """ + Replaces the textual token sequence for [TOOL_CALLS] + with its single special token ID. + """ + textual_tool_call_token_ids = mistral_tokenizer.encode( + text=mistral_tool_parser.bot_token, + add_special_tokens=False, + ) + # textual_tool_call_token_ids must not contain special tokens like bos, eos etc + special_tool_call_token_ids = [mistral_tool_parser.bot_token_id] + + # If the input is too short to contain the sequence, no replacement is possible + if not tokens or len(tokens) < len(textual_tool_call_token_ids): + return tokens + + result_tokens = [] + i = 0 + target_len = len(textual_tool_call_token_ids) + + while i < len(tokens): + # Check if the slice from the current position matches the target sequence + if tokens[i : i + target_len] == textual_tool_call_token_ids: + # If it matches, add the replacement and jump the index forward + result_tokens.extend(special_tool_call_token_ids) + i += target_len + else: + # Otherwise, just add the current token and move to the next one + result_tokens.append(tokens[i]) + i += 1 + + return result_tokens + + +def stream_delta_message_generator( + mistral_tool_parser: MistralToolParser, + mistral_tokenizer: TokenizerLike, + model_output: str | None, + tools: list[tuple[str, str]] | None, +) -> Generator[DeltaMessage, None, None]: + if ( + isinstance(mistral_tokenizer, MistralTokenizer) + and mistral_tokenizer.version >= 11 + ): + # With the newer versions of the tokenizer, + # we cannot tokenize free text + # so we need to create a list of messages to get tokenized + assert tools is not None + assistant_msg = AssistantMessage( + tool_calls=[ + ToolCall( + function=FunctionCall( + name=name, + arguments=arg, + ) + ) + for (name, arg) in tools + ], + ) + request = InstructRequest( + messages=[assistant_msg], + ) + all_token_ids = mistral_tokenizer.instruct.encode_instruct(request).tokens + else: + # Older versions of the tokenizer are + # able to encode directly the model's output (free text) into tokens + assert model_output is not None + all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False) + + all_token_ids = fix_tool_call_tokenization( + all_token_ids, mistral_tool_parser, mistral_tokenizer + ) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[: i + 1] + + (new_tokens, delta_text, new_prefix_offset, new_read_offset) = ( + detokenize_incrementally( + tokenizer=mistral_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=isinstance(mistral_tokenizer, MistralTokenizer), + spaces_between_special_tokens=True, + ) + ) + + current_text = previous_text + delta_text + + delta_message = mistral_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=None, # type: ignore[arg-type] + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = ( + previous_tokens + new_tokens if previous_tokens else new_tokens + ) + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + +def test_extract_tool_calls_no_tools(mistral_pre_v11_tool_parser): + model_output = "This is a test" + extracted_tool_calls = mistral_pre_v11_tool_parser.extract_tool_calls( + model_output, request=None + ) # type: ignore[arg-type] + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "single_tool_add", + "single_tool_weather", + "argument_before_name", + "argument_before_name_and_name_in_argument", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3.5, "b": 4}) + ) + ) + ], + None, + ), + ( + """[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + None, + ), + ( + """[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + None, + ), + ( + """[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_age", + arguments=json.dumps( + { + "name": "John Doe", + } + ), + ) + ) + ], + None, + ), + ], +) +def test_extract_tool_calls_pre_v11_tokenizer( + mistral_pre_v11_tool_parser, model_output, expected_tool_calls, expected_content +): + extracted_tool_calls = mistral_pre_v11_tool_parser.extract_tool_calls( + model_output, request=None + ) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +@pytest.mark.parametrize( + ids=[ + "single_tool_add", + "single_tool_weather", + "multiple_tool_calls", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add_this_and_that", + arguments=json.dumps({"a": 3.5, "b": 4}), + ) + ) + ], + None, + ), + ( + """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + None, + ), + ( + """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3.5, "b": 4}) + ) + ), + ToolCall( + function=FunctionCall( + name="multiply", arguments=json.dumps({"a": 3, "b": 6}) + ) + ), + ], + None, + ), + ], +) +def test_extract_tool_calls( + mistral_tool_parser, model_output, expected_tool_calls, expected_content +): + extracted_tool_calls = mistral_tool_parser.extract_tool_calls( + model_output, request=None + ) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +def _test_extract_tool_calls_streaming( + tool_parser, tokenizer, model_output, tools, expected_tool_calls, expected_content +): + other_content: str = "" + function_names: list[str] = [] + function_args_strs: list[str] = [] + tool_call_idx: int = -1 + tool_call_ids: list[str | None] = [] + + for delta_message in stream_delta_message_generator( + tool_parser, tokenizer, model_output, tools + ): + # role should never be streamed from tool parser + assert not delta_message.role + + if delta_message.content: + other_content += delta_message.content + + streamed_tool_calls = delta_message.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + # make sure only one diff is present - correct even for parallel + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + assert len(tool_parser.prev_tool_call_arr) > 0 + + # if a new tool is being called, set up empty arguments + if tool_call.index != tool_call_idx: + tool_call_idx = tool_call.index + function_args_strs.append("") + tool_call_ids.append(None) + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id and not tool_call_ids[tool_call.index]: + tool_call_ids[tool_call.index] = tool_call.id + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert isinstance(tool_call.function.name, str) + function_names.append(tool_call.function.name) + + if tool_call.function.arguments: + # make sure they're a string and then add them to the list + assert isinstance(tool_call.function.arguments, str) + + function_args_strs[tool_call.index] += tool_call.function.arguments + + assert other_content == expected_content + + actual_tool_calls = [ + ToolCall( + id=tool_call_id, + function=FunctionCall( + name=function_name, + arguments=partial_json_parser.ensure_json( + function_args_str, Allow.OBJ | Allow.STR + ), + ), + ) + for tool_call_id, function_name, function_args_str in zip( + tool_call_ids, function_names, function_args_strs + ) + ] + assert_tool_calls(actual_tool_calls, expected_tool_calls) + + +@pytest.mark.parametrize( + ids=[ + "no_tools", + "single_tool_add", + "single_tool_add_strings", + "single_tool_weather", + "argument_before_name", + "argument_before_name_and_name_in_argument", + "multiple_tools", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ("""This is a test""", [], """This is a test"""), + ( + """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3, "b": 4}) + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": "3", "b": "4"}) + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_age", + arguments=json.dumps( + { + "name": "John Doe", + } + ), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3.5, "b": 4}) + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ), + ], + "", + ), + ], +) +def test_extract_tool_calls_streaming_pre_v11_tokenizer( + mistral_pre_v11_tool_parser, + mistral_pre_v11_tokenizer, + model_output, + expected_tool_calls, + expected_content, +): + _test_extract_tool_calls_streaming( + mistral_pre_v11_tool_parser, + mistral_pre_v11_tokenizer, + model_output, + None, + expected_tool_calls, + expected_content, + ) + + +@pytest.mark.parametrize( + ids=[ + "single_tool_add", + "single_tool_add_strings", + "multiple_tools", + ], + argnames=["tools", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + [("add", '{"a": 3, "b": 4}')], + # [TOOL_CALLS]add{"a": 3, "b": 4} + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3, "b": 4}) + ) + ) + ], + "", + ), + ( + [("add_two_strings", '{"a": "3", "b": "4"}')], + # [TOOL_CALLS]add_two_strings{"a": "3", "b": "4"} + [ + ToolCall( + function=FunctionCall( + name="add_two_strings", + arguments=json.dumps({"a": "3", "b": "4"}), + ) + ) + ], + "", + ), + ( + [ + ("add", '{"a": 3.5, "b": 4}'), + ( + "get_current_weather", + '{"city": "San Francisco", "state": "CA", "unit": "celsius"}', # noqa: E501 + ), + ], + # [TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"} # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3.5, "b": 4}) + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ), + ], + "", + ), + ], +) +def test_extract_tool_calls_streaming( + mistral_tool_parser, + mistral_tokenizer, + tools, + expected_tool_calls, + expected_content, +): + _test_extract_tool_calls_streaming( + mistral_tool_parser, + mistral_tokenizer, + None, + tools, + expected_tool_calls, + expected_content, + ) + + +@pytest.mark.parametrize( + ids=[ + "single_tool_add", + "single_tool_weather", + "multiple_tool_calls", + "content_before_tool", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add_this_and_that", + arguments=json.dumps({"a": 3.5, "b": 4}), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3.5, "b": 4}) + ) + ), + ToolCall( + function=FunctionCall( + name="multiply", arguments=json.dumps({"a": 3, "b": 6}) + ) + ), + ], + "", + ), + ( + # Additional content should not be after the tool calls + """bla[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add_this_and_that", + arguments=json.dumps({"a": 3.5, "b": 4}), + ) + ) + ], + "bla", + ), + ], +) +def test_extract_tool_calls_streaming_one_chunk( + mistral_tool_parser, + mistral_tokenizer, + model_output, + expected_tool_calls, + expected_content, +): + if isinstance(mistral_tokenizer, MistralTokenizer): + all_token_ids = mistral_tokenizer.encode(model_output) + else: + all_token_ids = mistral_tokenizer.encode(model_output, add_special_tokens=False) + all_token_ids = fix_tool_call_tokenization( + all_token_ids, mistral_tool_parser, mistral_tokenizer + ) + + delta_message = mistral_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text=model_output, + delta_text=model_output, + previous_token_ids=[], + current_token_ids=all_token_ids, + delta_token_ids=all_token_ids, + request=None, + ) # type: ignore[arg-type] + assert isinstance(delta_message, DeltaMessage) + assert len(delta_message.tool_calls) == len(expected_tool_calls) + + assert_tool_calls(delta_message.tool_calls, expected_tool_calls) + + if delta_message.content is None: + assert expected_content == "" + else: + assert delta_message.content == expected_content + + +@pytest.mark.parametrize( + ids=[ + "no_tools", + "single_tool_add", + "single_tool_add_strings", + "single_tool_weather", + "argument_before_name", + "argument_before_name_and_name_in_argument", + "multiple_tools", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ("""This is a test""", [], """This is a test"""), + ( + """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3, "b": 4}) + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": "3", "b": "4"}) + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="get_age", + arguments=json.dumps( + { + "name": "John Doe", + } + ), + ) + ) + ], + "", + ), + ( + """[TOOL_CALLS] [{"arguments": {"a": 3.5, "b": 4}, "name": "add"}, {"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501 + [ + ToolCall( + function=FunctionCall( + name="add", arguments=json.dumps({"a": 3.5, "b": 4}) + ) + ), + ToolCall( + function=FunctionCall( + name="get_current_weather", + arguments=json.dumps( + {"city": "San Francisco", "state": "CA", "unit": "celsius"} + ), + ) + ), + ], + "", + ), + ], +) +def test_extract_tool_calls_streaming_pre_v11_tokenizer_one_chunk( + mistral_pre_v11_tool_parser, + mistral_pre_v11_tokenizer, + model_output, + expected_tool_calls, + expected_content, +): + if isinstance(mistral_pre_v11_tokenizer, MistralTokenizer): + all_token_ids = mistral_pre_v11_tokenizer.encode(model_output) + else: + all_token_ids = mistral_pre_v11_tokenizer.encode( + model_output, add_special_tokens=False + ) + all_token_ids = fix_tool_call_tokenization( + all_token_ids, mistral_pre_v11_tool_parser, mistral_pre_v11_tokenizer + ) + + delta_message = mistral_pre_v11_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text=model_output, + delta_text=model_output, + previous_token_ids=[], + current_token_ids=all_token_ids, + delta_token_ids=all_token_ids, + request=None, + ) # type: ignore[arg-type] + assert isinstance(delta_message, DeltaMessage) + assert len(delta_message.tool_calls) == len(expected_tool_calls) + + assert_tool_calls(delta_message.tool_calls, expected_tool_calls) + + if delta_message.content is None: + assert expected_content == "" + else: + assert delta_message.content == expected_content diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index 7584b903156b7..de7284a309c53 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -123,7 +123,7 @@ CONFIGS: dict[str, ServerConfig] = { "supports_parallel": True, "extended": True, }, - "mistral": { + "mistral-7b": { "model": "mistralai/Mistral-7B-Instruct-v0.3", "arguments": [ "--enforce-eager", @@ -145,6 +145,32 @@ CONFIGS: dict[str, ServerConfig] = { "call the tool. Otherwise, answer the user's query directly " "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " "to the user's question - just respond to it normally.", + "supports_parallel": True, + }, + "mistral-small-3.2": { + "model": "mistralai/Mistral-Small-3.2-24B-Instruct-2506", + "arguments": [ + "--enforce-eager", + "--no-enable-prefix-caching", + "--tool-call-parser", + "mistral", + "--tokenizer-mode", + "mistral", + "--config-format", + "mistral", + "--load-format", + "mistral", + "--tensor-parallel-size", + "4", + '--ignore-patterns="consolidated.safetensors"', + ], + "system_prompt": "You are a helpful assistant with access to tools. If a tool" + " that you have would be helpful to answer a user query, " + "call the tool. Otherwise, answer the user's query directly " + "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " + "to the user's question - just respond to it normally.", + "supports_parallel": True, + "extended": True, }, # FIXME: This test currently fails, need to debug why. # "granite20b": { diff --git a/tests/v1/core/test_reset_prefix_cache_e2e.py b/tests/v1/core/test_reset_prefix_cache_e2e.py index e543c30a156ec..083fc3f34f545 100644 --- a/tests/v1/core/test_reset_prefix_cache_e2e.py +++ b/tests/v1/core/test_reset_prefix_cache_e2e.py @@ -11,7 +11,9 @@ PROMPTS = [ ] -def test_reset_prefix_cache_e2e(): +def test_reset_prefix_cache_e2e(monkeypatch): + # "spawn" is required for test to be deterministic + monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") engine_args = EngineArgs( model="Qwen/Qwen3-0.6B", gpu_memory_utilization=0.2, diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py index e96759ed66a79..527a56ff49eec 100644 --- a/tests/v1/engine/test_engine_args.py +++ b/tests/v1/engine/test_engine_args.py @@ -9,6 +9,7 @@ from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.usage.usage_lib import UsageContext from vllm.utils.argparse_utils import FlexibleArgumentParser +from vllm.utils.hashing import _xxhash def test_prefix_caching_from_cli(): @@ -48,6 +49,21 @@ def test_prefix_caching_from_cli(): args = parser.parse_args(["--prefix-caching-hash-algo", "invalid"]) +@pytest.mark.skipif(_xxhash is None, reason="xxhash not installed") +def test_prefix_caching_xxhash_from_cli(): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + + # set hash algorithm to xxhash (pickle) + args = parser.parse_args(["--prefix-caching-hash-algo", "xxhash"]) + vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() + assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash" + + # set hash algorithm to xxhash_cbor + args = parser.parse_args(["--prefix-caching-hash-algo", "xxhash_cbor"]) + vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config() + assert vllm_config.cache_config.prefix_caching_hash_algo == "xxhash_cbor" + + def test_defaults_with_usage_context(): engine_args = EngineArgs(model="facebook/opt-125m") vllm_config: VllmConfig = engine_args.create_engine_config(UsageContext.LLM_CLASS) diff --git a/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py index 276de2ff8e2cd..b30556fbc81fb 100644 --- a/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py +++ b/tests/v1/entrypoints/openai/test_completion_with_image_embeds.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import base64 -import io import json import openai # use the official client for correctness check @@ -13,6 +11,7 @@ from transformers import AutoConfig from tests.conftest import ImageTestAssets from tests.utils import RemoteOpenAIServer +from vllm.utils.serial_utils import tensor2base64 # any model with a chat template should work here MODEL_NAME = "llava-hf/llava-1.5-7b-hf" @@ -50,18 +49,6 @@ async def client_with_image_embeds(server_with_image_embeds): yield async_client -def encode_image_embedding_to_base64(image_embedding) -> str: - """ - Encode image embedding to base64 string - """ - buffer = io.BytesIO() - torch.save(image_embedding, buffer) - buffer.seek(0) - binary_data = buffer.read() - base64_image_embedding = base64.b64encode(binary_data).decode("utf-8") - return base64_image_embedding - - @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("dtype", [torch.half, torch.float16, torch.float32]) @@ -73,7 +60,7 @@ async def test_completions_with_image_embeds( ): # Test case: Single image embeds input image_embeds = image_assets[0].image_embeds.to(dtype=dtype) - base64_image_embedding = encode_image_embedding_to_base64(image_embeds) + base64_image_embedding = tensor2base64(image_embeds) chat_completion = await client_with_image_embeds.chat.completions.create( messages=[ {"role": "system", "content": "You are a helpful assistant."}, diff --git a/tests/v1/kv_connector/unit/test_shared_storage_connector.py b/tests/v1/kv_connector/unit/test_shared_storage_connector.py index e7013a794a8c6..ff4697a978255 100644 --- a/tests/v1/kv_connector/unit/test_shared_storage_connector.py +++ b/tests/v1/kv_connector/unit/test_shared_storage_connector.py @@ -3,12 +3,14 @@ from dataclasses import asdict from typing import NamedTuple +import pytest from PIL import Image from vllm import LLM, EngineArgs, SamplingParams from vllm.assets.image import ImageAsset from vllm.config import KVTransferConfig from vllm.multimodal.utils import encode_image_base64 +from vllm.platforms import current_platform MODEL_NAME = "RedHatAI/Qwen2.5-VL-3B-Instruct-quantized.w8a8" @@ -108,6 +110,13 @@ def process_prompt(processor, llm: LLM, question: str, image_urls: list[Image]): print("-" * 50) +@pytest.mark.skipif( + current_platform.is_rocm(), + reason=( + "hipErrorLaunchFailure when running this test, see issue:" + "https://github.com/ROCm/pytorch/issues/2822" + ), +) def test_shared_storage_connector_hashes(tmp_path): """ Tests that SharedStorageConnector saves KV to the storage locations diff --git a/tests/v1/logits_processors/test_custom_offline.py b/tests/v1/logits_processors/test_custom_offline.py index 1899737737f4b..e3ddb6138cfdd 100644 --- a/tests/v1/logits_processors/test_custom_offline.py +++ b/tests/v1/logits_processors/test_custom_offline.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random -import sys from typing import Any import pytest @@ -10,7 +9,6 @@ from tests.utils import create_new_process_for_each_test from tests.v1.logits_processors.utils import ( DUMMY_LOGITPROC_ARG, DUMMY_LOGITPROC_FQCN, - DUMMY_LOGITPROC_MODULE, MAX_TOKENS, MODEL_NAME, POOLING_MODEL_NAME, @@ -18,7 +16,6 @@ from tests.v1.logits_processors.utils import ( CustomLogitprocSource, DummyLogitsProcessor, WrappedPerReqLogitsProcessor, - dummy_module, prompts, ) from tests.v1.logits_processors.utils import entry_points as fake_entry_points @@ -162,8 +159,6 @@ def test_custom_logitsprocs(monkeypatch, logitproc_source: CustomLogitprocSource kwargs: dict[str, list[str | type[LogitsProcessor]]] = {} if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN: # Scenario: load logitproc based on fully-qualified class name (FQCN) - # Inject dummy module which defines logitproc - sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN] elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS: # Scenario: load logitproc from provided class object diff --git a/tests/v1/logits_processors/test_custom_online.py b/tests/v1/logits_processors/test_custom_online.py index 3e0bb02ed68be..3dc6b89790157 100644 --- a/tests/v1/logits_processors/test_custom_online.py +++ b/tests/v1/logits_processors/test_custom_online.py @@ -14,11 +14,9 @@ from tests.utils import RemoteOpenAIServerCustom, create_new_process_for_each_te from tests.v1.logits_processors.utils import ( DUMMY_LOGITPROC_ARG, DUMMY_LOGITPROC_FQCN, - DUMMY_LOGITPROC_MODULE, MAX_TOKENS, MODEL_NAME, TEMP_GREEDY, - dummy_module, prompts, ) from tests.v1.logits_processors.utils import entry_points as fake_entry_points @@ -47,20 +45,14 @@ def _server_with_logitproc_entrypoint( main.main() -def _server_with_logitproc_module( +def _server_with_logitproc_fqcn( env_dict: dict[str, str] | None, model: str, vllm_serve_args: list[str], ) -> None: """Start vLLM server, inject module with dummy logitproc""" - - # Patch `modules` to inject dummy logitproc module from vllm.entrypoints.cli import main - sys.modules[DUMMY_LOGITPROC_MODULE] = dummy_module - - # fork is required for workers to see entrypoint patch - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "fork" if env_dict is not None: os.environ.update(env_dict) @@ -99,7 +91,7 @@ def server(default_server_args, request, monkeypatch): if request.param: # Launch server, append FQCN argument, inject dummy logitproc module args = default_server_args + request.param - _server_fxn = _server_with_logitproc_module + _server_fxn = _server_with_logitproc_fqcn else: # Launch server, inject dummy logitproc entrypoint args = default_server_args diff --git a/tests/v1/logits_processors/utils.py b/tests/v1/logits_processors/utils.py index b8548bc319554..e54da72e5e2ed 100644 --- a/tests/v1/logits_processors/utils.py +++ b/tests/v1/logits_processors/utils.py @@ -27,7 +27,7 @@ DUMMY_LOGITPROC_ARG = "target_token" TEMP_GREEDY = 0.0 MAX_TOKENS = 20 DUMMY_LOGITPROC_ENTRYPOINT = "dummy_logitproc" -DUMMY_LOGITPROC_MODULE = "DummyModule" +DUMMY_LOGITPROC_MODULE = "tests.v1.logits_processors.utils" DUMMY_LOGITPROC_FQCN = f"{DUMMY_LOGITPROC_MODULE}:DummyLogitsProcessor" diff --git a/tests/v1/spec_decode/test_speculators_eagle3.py b/tests/v1/spec_decode/test_speculators_eagle3.py index 5ce6e1593b5c1..9a252cfffc8f0 100644 --- a/tests/v1/spec_decode/test_speculators_eagle3.py +++ b/tests/v1/spec_decode/test_speculators_eagle3.py @@ -5,6 +5,7 @@ import torch from vllm.config import SpeculativeConfig from vllm.model_executor.models.interfaces import supports_eagle3 +from vllm.platforms import current_platform @pytest.mark.parametrize( @@ -21,6 +22,10 @@ from vllm.model_executor.models.interfaces import supports_eagle3 pytest.param( "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16", id="qwen3-eagle3-speculator-w4a16-verifier", + marks=pytest.mark.skipif( + current_platform.is_rocm(), + reason="The tests are skipped on rocm platform.", + ), ), ], ) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 0439bef1226e3..459abcfdd53cf 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -761,6 +761,10 @@ def test_init_kv_cache_with_kv_sharing_valid(): assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[1] == layer_1 +@pytest.mark.skipif( + current_platform.is_rocm(), + reason="Attention backend FLASHINFER is not supported on ROCm.", +) def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): """ The GPU model runner creates different views into the diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index a8f472d147a0d..35920d826578e 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -283,6 +283,28 @@ def _rocm_aiter_grouped_topk_fake( pass +# Cache whether aiter supports FP8 MLA parameters +_AITER_MLA_SUPPORTS_FP8: bool | None = None + + +def _check_aiter_mla_fp8_support() -> bool: + """Check if aiter.mla.mla_decode_fwd supports q_scale and kv_scale parameters.""" + global _AITER_MLA_SUPPORTS_FP8 + if _AITER_MLA_SUPPORTS_FP8 is None: + try: + import inspect + + from aiter.mla import mla_decode_fwd + + sig = inspect.signature(mla_decode_fwd) + _AITER_MLA_SUPPORTS_FP8 = ( + "q_scale" in sig.parameters and "kv_scale" in sig.parameters + ) + except Exception: + _AITER_MLA_SUPPORTS_FP8 = False + return _AITER_MLA_SUPPORTS_FP8 + + def _rocm_aiter_mla_decode_fwd_impl( q: torch.Tensor, kv_buffer: torch.Tensor, @@ -299,6 +321,16 @@ def _rocm_aiter_mla_decode_fwd_impl( ) -> None: from aiter.mla import mla_decode_fwd + kwargs = { + "sm_scale": sm_scale, + "logit_cap": logit_cap, + } + + # Only pass q_scale and kv_scale if the aiter library supports them + if _check_aiter_mla_fp8_support(): + kwargs["q_scale"] = q_scale + kwargs["kv_scale"] = kv_scale + mla_decode_fwd( q, kv_buffer.view(-1, 1, 1, q.shape[-1]), @@ -308,10 +340,7 @@ def _rocm_aiter_mla_decode_fwd_impl( kv_indices, kv_last_page_lens, max_seqlen_qo, - sm_scale=sm_scale, - logit_cap=logit_cap, - q_scale=q_scale, - kv_scale=kv_scale, + **kwargs, ) diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 76068f86ebfb3..2625562aadd36 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -104,7 +104,8 @@ class FixFunctionalizationPass(VllmInductorPass): mutated_args = {1: "result"} self.defunctionalize(graph, node, mutated_args) elif ( - at_target + hasattr(torch.ops.vllm, "flashinfer_trtllm_fused_allreduce_norm") + and at_target == torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default ): mutated_args = { diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 00530846fce00..91f083a5534ba 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -30,7 +30,7 @@ CacheDType = Literal[ "fp8_ds_mla", ] MambaDType = Literal["auto", "float32"] -PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] +PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"] KVOffloadingBackend = Literal["native", "lmcache"] @@ -77,9 +77,21 @@ class CacheConfig: """Whether to enable prefix caching.""" prefix_caching_hash_algo: PrefixCachingHashAlgo = "sha256" """Set the hash algorithm for prefix caching:\n - - "sha256" uses Pickle for object serialization before hashing.\n + - "sha256" uses Pickle for object serialization before hashing. This is the + current default, as SHA256 is the most secure choice to avoid potential + hash collisions.\n - "sha256_cbor" provides a reproducible, cross-language compatible hash. It - serializes objects using canonical CBOR and hashes them with SHA-256.""" + serializes objects using canonical CBOR and hashes them with SHA-256.\n + - "xxhash" uses Pickle serialization with xxHash (128-bit) for faster, + non-cryptographic hashing. Requires the optional ``xxhash`` package. + IMPORTANT: Use of a hashing algorithm that is not considered + cryptographically secure theoretically increases the risk of hash collisions, + which can cause undefined behavior or even leak private information in + multi-tenant environments. Even if collisions are still very unlikely, it is + important to consider your security risk tolerance against the performance + benefits before turning this on.\n + - "xxhash_cbor" combines canonical CBOR serialization with xxHash for + reproducible hashing. Requires the optional ``xxhash`` package.""" cpu_offload_gb: float = Field(default=0, ge=0) """The space in GiB to offload to CPU, per GPU. Default is 0, which means no offloading. Intuitively, this argument can be seen as a virtual way to diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 963b091939e0e..d3d50e6ae7b2e 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -4,7 +4,7 @@ import enum from collections import Counter from collections.abc import Callable -from dataclasses import asdict, field +from dataclasses import field from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, Literal @@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass import vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass -from vllm.config.utils import config, handle_deprecated +from vllm.config.utils import config, get_hash_factors, handle_deprecated, hash_factors from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.import_utils import resolve_obj_by_qualname @@ -196,7 +196,16 @@ class PassConfig: Any new fields that affect compilation should be added to the hash. Any future fields that don't affect compilation should be excluded. """ - return InductorPass.hash_dict(asdict(self)) + + ignored_fields = [ + "enable_fusion", + "enable_attn_fusion", + "enable_noop", + "enable_sequence_parallelism", + "enable_async_tp", + "enable_fi_allreduce_fusion", + ] + return hash_factors(get_hash_factors(self, ignored_factors=ignored_fields)) @field_validator( "fuse_norm_quant", @@ -267,14 +276,6 @@ class PassConfig: "v0.13.0 or v1.0.0, whichever is sooner", ) - # Force old flags to None to ensure they are not used - self.enable_fusion = None - self.enable_attn_fusion = None - self.enable_noop = None - self.enable_sequence_parallelism = None - self.enable_async_tp = None - self.enable_fi_allreduce_fusion = None - if not self.eliminate_noops: if self.fuse_norm_quant or self.fuse_act_quant: logger.warning_once( diff --git a/vllm/config/model.py b/vllm/config/model.py index 5de97697698a1..ae5189ce68d9a 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -84,7 +84,7 @@ TaskOption = Literal[ "transcription", "draft", ] -TokenizerMode = Literal["auto", "hf", "slow", "mistral"] +TokenizerMode = Literal["auto", "hf", "slow", "mistral", "deepseek_v32"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] LogprobsMode = Literal[ "raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs" @@ -141,6 +141,7 @@ class ModelConfig: - "hf" will use the fast tokenizer if available.\n - "slow" will always use the slow tokenizer.\n - "mistral" will always use the tokenizer from `mistral_common`.\n + - "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n - Other custom values can be supported via plugins.""" trust_remote_code: bool = False """Trust remote code (e.g., from HuggingFace) when downloading the model @@ -1779,20 +1780,22 @@ class ModelConfig: return False elif attn_type == "decoder": pooling_type = self.pooler_config.pooling_type.lower() - if pooling_type in ["all", "mean", "step", "cls"]: + if pooling_type in ["mean", "step", "cls"]: logger.debug( "Pooling models with %s pooling does not " "support chunked prefill.", pooling_type, ) return False - else: - # pooling_type == "last" + elif pooling_type in ["all", "last"]: logger.debug( - "Pooling models with causal attn and last pooling support " - "chunked prefill." + "Pooling models with causal attn and %s pooling support " + "chunked prefill.", + pooling_type, ) return True + else: + raise ValueError(f"{pooling_type=} not supported.") # vllm currently does not have pooling models using hybrid, # attention_free or encoder_decoder attn types. return attn_type != "encoder_decoder" @@ -1816,20 +1819,22 @@ class ModelConfig: return False elif attn_type == "decoder": pooling_type = self.pooler_config.pooling_type.lower() - if pooling_type in ["all", "mean", "step", "cls"]: + if pooling_type in ["mean", "step", "cls"]: logger.debug( "Pooling models with %s pooling does not " "support prefix caching.", pooling_type, ) return False - else: - # pooling_type == "last" + elif pooling_type in ["all", "last"]: logger.debug( - "Pooling models with causal attn and last pooling support " - "prefix caching." + "Pooling models with causal attn and %s pooling support " + "prefix caching.", + pooling_type, ) return True + else: + raise ValueError(f"{pooling_type=} not supported.") # vllm currently does not have pooling models using hybrid, # attention_free or encoder_decoder attn types. return False diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 4a8c8bc17cfc3..20de672257107 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -593,10 +593,14 @@ class ParallelConfig: "max_parallel_loading_workers is currently " "not supported and will be ignored." ) - if self.distributed_executor_backend not in ("mp", "uni") and self.nnodes > 1: + allowed_backends = ("mp", "uni", "external_launcher") + if ( + self.distributed_executor_backend not in allowed_backends + and self.nnodes > 1 + ): raise ValueError( "nnodes > 1 can only be set when distributed executor " - "backend is mp or uni." + "backend is mp, uni or external_launcher." ) @property diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index df871dd7cbe4f..02f51a1dce112 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -190,3 +190,8 @@ KVConnectorFactory.register_connector( "vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector", "DecodeBenchConnector", ) +KVConnectorFactory.register_connector( + "MooncakeConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.mooncake_connector", + "MooncakeConnector", +) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index b8eb5ea3b4939..99d3be57c1381 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -4,13 +4,14 @@ KV cache helper for store. """ +from dataclasses import dataclass from typing import TYPE_CHECKING, Literal import torch -import vllm.envs as envs -from vllm import _custom_ops as ops -from vllm.config import VllmConfig, get_current_vllm_config +from vllm.attention.backends.abstract import AttentionBackend +from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.config import get_current_vllm_config from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.logger import init_logger from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput @@ -21,89 +22,6 @@ if TYPE_CHECKING: logger = init_logger(__name__) -class model_aware_kv_ops_helper: - def __init__(self, config: VllmConfig): - self.is_deepseek_mla = config.model_config.is_deepseek_mla - self.use_mla_opt = not envs.VLLM_MLA_DISABLE - self.tp_size = config.parallel_config.tensor_parallel_size - - def get_model_args(self, model_executable: torch.nn.Module): - model_config = model_executable.model.config - self.model_executable = model_executable - num_heads = int(model_config.num_key_value_heads / self.tp_size) - hidden_size = model_config.hidden_size - num_attention_heads = model_config.num_attention_heads - - # Deepseek's MLA (Multi-head Latent Attention) uses two different - # kv_cache shapes based on whether VLLM_MLA_DISABLE is set to 0. - # When VLLM_MLA_DISABLE=0 (default), forward absorb is applied, - # resulting in a kv_cache shape of [num_blks, blk_size, 1, - # kv_lora_rank + qk_rope_head_dim]. - # When VLLM_MLA_DISABLE=1, standard FA is used instead, leading - # to a kv_cache shape of [2, num_blks, blk_size, - # num_key_value_heads / tp, qk_nope_head_dim + qk_rope_head_dim]. - # For more details, see vllm/v1/attention/backends/mla/common.py. - if self.is_deepseek_mla and self.use_mla_opt: - head_size = model_config.kv_lora_rank + model_config.qk_rope_head_dim - num_heads = 1 - elif self.is_deepseek_mla and not self.use_mla_opt: - head_size = model_config.qk_nope_head_dim + model_config.qk_rope_head_dim - else: - head_size = getattr(model_config, "head_dim", None) - if head_size is None: - head_size = int(hidden_size // num_attention_heads) - - return num_heads, head_size - - def get_kv_from_cache(self, kv_cache, num_heads, head_size): - if self.is_deepseek_mla and self.use_mla_opt: - key_cache = kv_cache.reshape(-1, num_heads, head_size) - value_cache = kv_cache.reshape(-1, num_heads, head_size) - else: - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) - return key_cache, value_cache - - def put_kv_to_cache( - self, - model_executable: torch.nn.Module, - keys, - values, - layer, - kv_cache, - slot_mapping, - start_pos, - end_pos, - ): - model_config = model_executable.model.config - - if self.is_deepseek_mla and self.use_mla_opt: - layer.self_attn.attn = layer.self_attn.mla_attn - k_c_normed_k_pe = keys.squeeze(1) - k_c_normed = k_c_normed_k_pe[:, : model_config.kv_lora_rank] - k_pe = k_c_normed_k_pe[:, model_config.kv_lora_rank :] - ops.concat_and_cache_mla( - k_c_normed.to(kv_cache.device), - k_pe.to(kv_cache.device), - kv_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - ) - else: - key_cache, value_cache = kv_cache[0], kv_cache[1] - ops.reshape_and_cache_flash( - keys.to(key_cache.device), - values.to(value_cache.device), - key_cache, - value_cache, - slot_mapping[start_pos:end_pos], - layer.self_attn.attn.kv_cache_dtype, - layer.self_attn.attn._k_scale, - layer.self_attn.attn._v_scale, - ) - - def get_kv_connector_cache_layout(): # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is # used for faster transfer. @@ -266,3 +184,124 @@ def copy_kv_blocks( src_tensor = src_kv_caches[layer_name] dst_tensor = dst_kv_caches[layer_name] copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) + + +@dataclass +class TpKVTopology: + """ + Helper class for tensor parallel and KV topology information for + mapping between local and remote TP workers. + """ + + tp_rank: int + remote_tp_size: dict[str, int] + is_mla: bool + total_num_kv_heads: int + attn_backend: type[AttentionBackend] + engine_id: str + remote_block_size: dict[str, int] + + def __post_init__(self): + # Figure out whether the first dimension of the cache is K/V + # or num_blocks. This is used to register the memory regions correctly. + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 + ) + # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D], + # we just mock num_blocks to 1 for the dimension check below. + self._is_kv_layout_blocks_first = ( + len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 + ) + + attn_backend = AttentionBackendEnum[self.attn_backend.get_name()] + self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS + + @property + def is_kv_layout_blocks_first(self) -> bool: + return self._is_kv_layout_blocks_first + + @property + def split_k_and_v(self) -> bool: + # Whether to register regions for K and V separately (when present). + return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first) + + @property + def tp_size(self) -> int: + return self.remote_tp_size[self.engine_id] + + @property + def block_size(self) -> int: + return self.remote_block_size[self.engine_id] + + def tp_ratio( + self, + remote_tp_size: int, + ) -> int: + """ + Calculate the tensor parallel ratio between local and remote TP. + We can think of it as the number of local TP workers-per-remote TP + workers. Local workers will read from the same remote TP worker in + groups of size `tp_ratio`. + """ + assert self.tp_size % remote_tp_size == 0, ( + f"Local tensor parallel size {self.tp_size} is not divisible " + f"by remote tensor parallel size {remote_tp_size}." + ) + return self.tp_size // remote_tp_size + + def block_size_ratio( + self, + remote_block_size: int, + ) -> float: + """ + Calculate the block size ratio between local and remote TP. + """ + assert self.block_size % remote_block_size == 0, ( + f"Local block size {self.block_size} is not divisible " + f"by remote block size {remote_block_size} or vice versa." + ) + return self.block_size // remote_block_size + + def tp_ratio_from_engine_id( + self, + remote_engine_id: str, + ) -> int: + remote_tp_size = self.remote_tp_size[remote_engine_id] + return self.tp_ratio(remote_tp_size) + + def block_size_ratio_from_engine_id( + self, + remote_engine_id: str, + ) -> float: + remote_block_size = self.remote_block_size[remote_engine_id] + return self.block_size_ratio(remote_block_size) + + def is_kv_replicated(self, engine_id: str) -> bool: + """ + Whether the KV cache is replicated across TP workers due to the + number of TP workers being greater than the number of KV heads. + """ + tp_size = self.remote_tp_size[engine_id] + return tp_size // self.total_num_kv_heads >= 1 + + def replicates_kv_cache(self, remote_engine_id: str) -> bool: + # MLA is always replicated as the hidden dim can't be split. + return self.is_mla or self.is_kv_replicated(remote_engine_id) + + def get_target_remote_rank( + self, + remote_tp_size: int, + ) -> int: + """ + Get the remote TP rank (on P) that the current local TP rank + (on D) will read from. + """ + tp_ratio = self.tp_ratio(remote_tp_size) + return self.tp_rank // tp_ratio + + def get_target_remote_rank_from_engine_id( + self, + remote_engine_id: str, + ) -> int: + remote_tp_size = self.remote_tp_size[remote_engine_id] + return self.get_target_remote_rank(remote_tp_size) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py new file mode 100644 index 0000000000000..705960aebe2da --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py @@ -0,0 +1,914 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import threading +import time +import uuid +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import msgspec +import numpy as np +import torch +import zmq +import zmq.asyncio + +from vllm import envs +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.selector import get_attn_backend +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, +) +from vllm.forward_context import ForwardContext +from vllm.logger import init_logger +from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket +from vllm.v1.attention.backends.utils import get_kv_cache_layout +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import RequestStatus + +try: + from mooncake.engine import TransferEngine +except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run VLLM with MooncakeTransferEngine." + ) from e + +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + +EngineId = str +ReqId = str + +TRANS_DONE = b"trans_done" +TRANS_ERROR = b"trans_error" + +logger = init_logger(__name__) + + +class MooncakeAgentMetadata( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True, +): + remote_hostname: str + remote_port: int + request_ids: list[ReqId] + kv_caches_base_addr: list[int] + block_ids: list[list[int]] + + +@dataclass +class RecvReqMeta: + local_block_ids: list[int] + remote_host: str + remote_port: int + + +@dataclass +class SendBlockMeta: + local_block_ids: list[int] + ready: threading.Event + expire_time: float = float("inf") + + +@dataclass +class SendReqMeta: + reqs: dict[ReqId, SendBlockMeta] + lock: threading.Lock + + +@dataclass +class FinishedSendReqSet: + set: set[ReqId] + lock: threading.Lock + + +@dataclass +class FinishedReceiveReqSet: + set: set[ReqId] + lock: asyncio.Lock + + +class MooncakeConnectorMetadata(KVConnectorMetadata): + def __init__(self): + self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {} + self.reqs_to_send: dict[ReqId, list[int]] = {} + + def add_new_req( + self, + request_id: ReqId, + local_block_ids: list[int], + kv_transfer_params: dict[str, Any], + load_remote_cache: bool = True, + ): + if load_remote_cache: + self.reqs_to_recv[request_id] = RecvReqMeta( + local_block_ids=local_block_ids, + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], + ) + else: + self.reqs_to_send[request_id] = local_block_ids + + +class MooncakeConnector(KVConnectorBase_V1): + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__(vllm_config, role, kv_cache_config) + + assert vllm_config.kv_transfer_config is not None + assert vllm_config.kv_transfer_config.engine_id is not None + self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler: MooncakeConnectorScheduler | None = ( + MooncakeConnectorScheduler(vllm_config, self.engine_id) + ) + self.connector_worker: MooncakeConnectorWorker | None = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int + ) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens + ) + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens + ) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def get_finished( + self, finished_req_ids: set[str] + ) -> tuple[set[str] | None, set[str] | None]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._connector_metadata, MooncakeConnectorMetadata) + self.connector_worker.start_load_kv(self._connector_metadata) + + def wait_for_layer_load(self, layer_name: str) -> None: + """MooncakeConnector does not do layerwise saving.""" + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: AttentionMetadata, + **kwargs, + ) -> None: + """MooncakeConnector does not save explicitly.""" + pass + + def wait_for_save(self): + pass + + +class MooncakeConnectorScheduler: + """Implementation of Scheduler side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + self.vllm_config = vllm_config + self.engine_id: EngineId = engine_id + self.side_channel_host = get_ip() + self.side_channel_port = get_mooncake_side_channel_port(vllm_config) + + assert vllm_config.kv_transfer_config + self.kv_role = vllm_config.kv_transfer_config.kv_role + logger.info("Initializing Mooncake Transfer Engine Scheduler %s", engine_id) + + # Requests that need to start recv/send. + # New requests are added by update_state_after_alloc in + # the scheduler. Used to make metadata passed to Worker. + self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} + self._reqs_need_send: dict[ReqId, list[int]] = {} + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int + ) -> tuple[int, bool]: + """ + For remote prefill, pull all prompt blocks from remote + asynchronously relative to engine execution. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + Returns: + * the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + * true if the external KV cache tokens will be loaded + asynchronously (between scheduler steps). + """ + + params = request.kv_transfer_params + logger.debug( + "MooncakeConnector get_num_new_matched_tokens: " + "num_computed_tokens=%s, kv_transfer_params=%s", + num_computed_tokens, + params, + ) + + if params is not None and params.get("do_remote_prefill"): + # Remote prefill: get all prompt blocks from remote. + token_ids = request.prompt_token_ids or [] + count = len(token_ids) - num_computed_tokens + if count > 0: + return count, True + + # No remote prefill for this request. + return 0, False + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + params = request.kv_transfer_params + logger.debug( + "MooncakeConnector update_state_after_alloc: " + "num_external_tokens=%s, kv_transfer_params=%s", + num_external_tokens, + params, + ) + + if not params: + return + + if params.get("do_remote_prefill"): + assert self.kv_role != "kv_producer" + if all(p in params for p in ("remote_host", "remote_port")): + # If remote_blocks and num_external_tokens = 0, we have + # a full prefix cache hit on the D worker. We need to call + # send_notif in _read_blocks to free the memory on the P. + local_block_ids = ( + blocks.get_unhashed_block_ids() if num_external_tokens > 0 else [] + ) + # Get unhashed blocks to pull from remote. + self._reqs_need_recv[request.request_id] = (request, local_block_ids) + else: + logger.warning( + "Got invalid KVTransferParams: %s. This " + "request will not utilize KVTransfer", + params, + ) + # Only trigger 1 KV transfer per request. + params["do_remote_prefill"] = False + + elif params.get("do_remote_decode"): + # Add an empty list to worker to create event. + self._reqs_need_send[request.request_id] = [] + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + meta = MooncakeConnectorMetadata() + + # Loop through scheduled reqs and convert to RecvReqMeta. + if self.kv_role != "kv_producer": + for req_id, (req, block_ids) in self._reqs_need_recv.items(): + assert req.kv_transfer_params is not None + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params=req.kv_transfer_params, + ) + self._reqs_need_recv.clear() + + if self.kv_role != "kv_consumer": + for req_id, block_ids in self._reqs_need_send.items(): + meta.add_new_req( + request_id=req_id, + local_block_ids=block_ids, + kv_transfer_params={}, + load_remote_cache=False, + ) + self._reqs_need_send.clear() + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. + """ + + params = request.kv_transfer_params + logger.debug( + "MooncakeConnector request_finished, request_status=%s, " + "kv_transfer_params=%s", + request.status, + params, + ) + if not params: + return False, None + + if params.get("do_remote_prefill"): + # If do_remote_prefill is still True when the request is finished, + # update_state_after_alloc must not have been called (the request + # must have been aborted before it was scheduled). + # To avoid stranding the prefill blocks in the prefill instance, + # we must add empty block_ids to _reqs_need_recv so that our + # worker side will notify and free blocks in the prefill instance. + assert self.kv_role != "kv_producer" + self._reqs_need_recv[request.request_id] = (request, []) + params["do_remote_prefill"] = False + return False, None + + if ( + not params.get("do_remote_decode") + or request.status != RequestStatus.FINISHED_LENGTH_CAPPED + ): + return False, None + + assert self.kv_role != "kv_consumer" + + # TODO: check whether block_ids actually ever be 0. If not we could + # remove the conditional below + delay_free_blocks = len(block_ids) > 0 + + if delay_free_blocks: + self._reqs_need_send[request.request_id] = block_ids + + return delay_free_blocks, dict( + do_remote_prefill=True, + do_remote_decode=False, + remote_host=self.side_channel_host, + remote_port=self.side_channel_port, + ) + + +class MooncakeConnectorWorker: + """Implementation of Worker side methods""" + + def __init__(self, vllm_config: VllmConfig, engine_id: str): + logger.info("Initializing Mooncake Transfer Engine worker %s", engine_id) + + self.vllm_config = vllm_config + + self.engine = TransferEngine() + self.hostname = get_ip() + ret_value = self.engine.initialize(self.hostname, "P2PHANDSHAKE", "rdma", "") + if ret_value != 0: + raise RuntimeError("Mooncake Transfer Engine initialization failed.") + + self.rpc_port = self.engine.get_rpc_port() + + logger.debug( + "Mooncake Transfer Engine initialized at %s:%d", + self.hostname, + self.rpc_port, + ) + + # Mooncake handshake port. + self.side_channel_port: int = get_mooncake_side_channel_port(vllm_config) + + self.engine_id: EngineId = engine_id + self.tp_rank = get_tensor_model_parallel_rank() + self.world_size = get_tensor_model_parallel_world_size() + self.tp_group = get_tp_group() + self.num_blocks = 0 + + assert vllm_config.kv_transfer_config + self.kv_role = vllm_config.kv_transfer_config.kv_role + self.num_workers = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "num_workers", 10 + ) + + self.kv_caches_base_addr: list[int] = [] + self.device_kv_caches: dict[str, torch.Tensor] = {} + self.reqs_need_send: SendReqMeta = SendReqMeta(reqs={}, lock=threading.Lock()) + + # For kv_both, we will act both prefiller and decoder. + if self.kv_role != "kv_consumer": + # Background thread for sending kvcaches to D. + self._mooncake_sender_t: threading.Thread | None = None + # Background thread for processing new sending requests. + self._sender_executor = ThreadPoolExecutor( + max_workers=self.num_workers, thread_name_prefix="vllm-mooncake-sender" + ) + logger.debug( + "Mooncake Prefiller: use %d workers to send kvcaches", self.num_workers + ) + if self.kv_role != "kv_producer": + self.receiver_loop = asyncio.new_event_loop() + self._mooncake_receiver_t = threading.Thread( + target=self._receiver_loop, args=(self.receiver_loop,), daemon=True + ) + self._mooncake_receiver_t.start() + logger.debug("Mooncake Decoder: start receiver thread") + + self.finished_sending_reqs: FinishedSendReqSet = FinishedSendReqSet( + set(), threading.Lock() + ) + self.finished_recving_reqs: FinishedReceiveReqSet = FinishedReceiveReqSet( + set(), asyncio.Lock() + ) + + self.block_size = vllm_config.cache_config.block_size + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.use_mla = self.model_config.use_mla + + backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + use_mla=self.use_mla, + ) + self.backend_name = backend.get_name() + self.kv_cache_layout = get_kv_cache_layout() + logger.debug("Detected attention backend %s", self.backend_name) + logger.debug("Detected kv cache layout %s", self.kv_cache_layout) + + self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} + self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} + self.kv_topo = TpKVTopology( + tp_rank=self.tp_rank, + engine_id=self.engine_id, + remote_tp_size=self._tp_size, # shared state + remote_block_size=self._block_size, # shared state + is_mla=self.use_mla, + total_num_kv_heads=self.model_config.get_total_num_kv_heads(), + attn_backend=backend, + ) + self._use_pallas = self.kv_topo._use_pallas + + self.zmq_ctx = zmq.Context() + self.async_zmq_ctx = zmq.asyncio.Context() + self._encoder = msgspec.msgpack.Encoder() + self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata) + + def __del__(self): + self.shutdown() + + def shutdown(self): + """Cleanup background threads on destruction.""" + self.zmq_ctx.term() + self.async_zmq_ctx.term() + if self.kv_role != "kv_consumer": + self._sender_executor.shutdown(wait=False) + if self._mooncake_sender_t: + self._mooncake_sender_t.join() + if self.kv_role != "kv_producer" and self.receiver_loop.is_running(): + self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop) + self._mooncake_receiver_t.join() + + def _receiver_loop(self, loop: asyncio.AbstractEventLoop): + asyncio.set_event_loop(loop) + loop.run_forever() + + def _mooncake_sender( + self, ready_event: threading.Event, base_port: int, tp_rank: int + ): + """ + Background thread that listens for Mooncake requests, dispatches them + to a thread pool, and sends acknowledgments upon completion. + """ + + frontend_path = make_zmq_path("tcp", self.hostname, base_port + tp_rank) + frontend = make_zmq_socket(self.zmq_ctx, frontend_path, zmq.ROUTER) + logger.debug("Mooncake sender starting listening on path: %s", frontend_path) + + backend_path = make_zmq_path("inproc", str(uuid.uuid4())) + backend = make_zmq_socket(self.zmq_ctx, backend_path, zmq.PULL) + + poller = zmq.Poller() + poller.register(frontend, zmq.POLLIN) + poller.register(backend, zmq.POLLIN) + + ready_event.set() + + try: + while True: + sockets = dict(poller.poll()) + + if frontend in sockets: + identity, _, metadata_bytes = frontend.recv_multipart() + self._sender_executor.submit( + self._sender_worker, + identity, + metadata_bytes, + backend_path, + ) + + if backend in sockets: + identity, status = backend.recv_multipart() + frontend.send_multipart((identity, b"", status)) + + except zmq.ContextTerminated: + logger.debug("ZMQ context terminated, exiting Mooncake sender thread.") + except Exception as e: + logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e)) + finally: + frontend.close() + backend.close() + + def _sender_worker( + self, identity: bytes, metadata_bytes: bytes, worker_channel_path: str + ): + status = TRANS_ERROR + + try: + metadata = self._decoder.decode(metadata_bytes) + self.send_kv_to_decode(metadata) + status = TRANS_DONE + except Exception as e: + logger.error("Error processing Mooncake handshake: %s", e) + finally: + pusher = make_zmq_socket(self.zmq_ctx, worker_channel_path, zmq.PUSH) + try: + pusher.send_multipart((identity, status)) + except zmq.ZMQError as e: + logger.warning( + "Internal error, maybe the server is shutting down. Error: %s", + e, + ) + finally: + pusher.close() + + def send_kv_to_decode(self, meta: MooncakeAgentMetadata): + send_reqs: list[tuple[ReqId, SendBlockMeta]] = [] + with self.reqs_need_send.lock: + for req_id in meta.request_ids: + send_meta = self.reqs_need_send.reqs.get(req_id) + if send_meta is None: + logger.warning("Request %s not found in reqs_need_send", req_id) + return + # Mark it as not expired. We will send it now. + send_meta.expire_time = float("inf") + send_reqs.append((req_id, send_meta)) + + self._send_blocks(send_reqs, meta) + + with self.reqs_need_send.lock: + for req_id in meta.request_ids: + del self.reqs_need_send.reqs[req_id] + + with self.finished_sending_reqs.lock: + self.finished_sending_reqs.set.update(meta.request_ids) + + def _send_blocks( + self, + send_reqs: list[tuple[ReqId, SendBlockMeta]], + agent_meta: MooncakeAgentMetadata, + ): + src_ptrs = [] + dst_ptrs = [] + lengths = [] + local_base_addr = self.kv_caches_base_addr + remote_base_addr = agent_meta.kv_caches_base_addr + block_len = self.block_len + remote_session = f"{agent_meta.remote_hostname}:{agent_meta.remote_port}" + + assert len(send_reqs) == len(agent_meta.block_ids) + for (req_id, send_meta), remote_block_ids in zip( + send_reqs, agent_meta.block_ids + ): + send_meta.ready.wait() + + num_remote_blocks = len(remote_block_ids) + if num_remote_blocks == 0: + continue + + local_block_ids = send_meta.local_block_ids + # Partial prefix cache hit: just read uncomputed blocks. + num_local_blocks = len(local_block_ids) + assert num_local_blocks >= num_remote_blocks + if num_local_blocks > num_remote_blocks: + local_block_ids = local_block_ids[-num_remote_blocks:] + + # Group by indices + group_local_block_ids, group_remote_block_ids = group_concurrent_contiguous( + local_block_ids, remote_block_ids + ) + + for local_layer_addr, remote_layer_addr in zip( + local_base_addr, remote_base_addr + ): + for group_local_block_id, group_remote_block_id in zip( + group_local_block_ids, group_remote_block_ids + ): + src_ptrs.append( + local_layer_addr + group_local_block_id[0] * block_len + ) + dst_ptrs.append( + remote_layer_addr + group_remote_block_id[0] * block_len + ) + lengths.append(block_len * len(group_local_block_id)) + + logger.debug( + "Sending kv_caches for request %s (%d blocks) to %s", + req_id, + num_remote_blocks, + remote_session, + ) + + start_time = time.perf_counter() + ret_value = self.engine.batch_transfer_sync_write( + remote_session, src_ptrs, dst_ptrs, lengths + ) + if ret_value != 0: + raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}") + + logger.debug( + "Sending to %s done, took %s", + remote_session, + time.perf_counter() - start_time, + ) + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """Register the KV Cache data in mooncake.""" + + logger.info("Registering KV_Caches. use_mla: %s", self.use_mla) + + kv_data_ptrs = [] + kv_data_lens = [] + seen_base_addresses = [] + + split_k_and_v = self.kv_topo.split_k_and_v + tensor_size_bytes = None + for layer_name, cache_or_caches in kv_caches.items(): + logger.debug( + "registering layer %s with shape %s", layer_name, cache_or_caches.shape + ) + cache_list = cache_or_caches if split_k_and_v else [cache_or_caches] + + for cache in cache_list: + base_addr = cache.data_ptr() + if base_addr in seen_base_addresses: + continue + + seen_base_addresses.append(base_addr) + curr_tensor_size_bytes = cache.nbytes + + if tensor_size_bytes is None: + tensor_size_bytes = curr_tensor_size_bytes + self.num_blocks = cache.shape[0] + + assert tensor_size_bytes == curr_tensor_size_bytes, ( + "All kv cache tensors must have the same size" + ) + kernel_block_size = cache.shape[-2 if self.use_mla else -3] + assert self.block_size == kernel_block_size + kv_data_ptrs.append(base_addr) + kv_data_lens.append(tensor_size_bytes) + + self.kv_caches_base_addr = seen_base_addresses + + ret_value = self.engine.batch_register_memory(kv_data_ptrs, kv_data_lens) + if ret_value != 0: + raise RuntimeError("Mooncake batch memory registration failed.") + + assert tensor_size_bytes is not None + assert self.num_blocks != 0 + assert tensor_size_bytes % self.num_blocks == 0 + self.block_len = tensor_size_bytes // self.num_blocks + self.device_kv_caches = kv_caches + logger.debug( + "registered num_blocks=%d block_len=%d", self.num_blocks, self.block_len + ) + + # No need to launch server for D node. + if self.kv_role == "kv_consumer": + return + + ready_event = threading.Event() + self._mooncake_sender_t = threading.Thread( + target=self._mooncake_sender, + args=(ready_event, self.side_channel_port, self.tp_rank), + daemon=True, + name="mooncake_sender", + ) + self._mooncake_sender_t.start() + ready_event.wait() # Wait for listener ZMQ socket to be ready. + + async def fetch_finished_recving_reqs(self) -> set[ReqId]: + async with self.finished_recving_reqs.lock: + finished_recving_reqs = self.finished_recving_reqs.set + self.finished_recving_reqs.set = set() + return finished_recving_reqs + + def get_finished(self) -> tuple[set[str] | None, set[str] | None]: + """ + Get requests that are done sending or recving on this specific worker. + The scheduler process (via the MultiprocExecutor) will use this output + to track which workers are done. + """ + fut = None + if self.kv_role != "kv_producer": + fut = asyncio.run_coroutine_threadsafe( + self.fetch_finished_recving_reqs(), self.receiver_loop + ) + + if self.kv_role != "kv_consumer": + with self.finished_sending_reqs.lock: + finished_sending_reqs = self.finished_sending_reqs.set + self.finished_sending_reqs.set = set() + else: + finished_sending_reqs = set() + + finished_recving_reqs = fut.result() if fut else set() + + if finished_sending_reqs or finished_recving_reqs: + logger.debug( + "Rank %s, get_finished: %s requests done sending " + "and %s requests done recving", + self.tp_rank, + len(finished_sending_reqs), + len(finished_recving_reqs), + ) + + # Handle timeout to avoid stranding blocks on remote. + now = time.perf_counter() + with self.reqs_need_send.lock: + expired_reqs = [ + req_id + for req_id, send_meta in self.reqs_need_send.reqs.items() + if send_meta.expire_time < now + ] + for req_id in expired_reqs: + logger.warning( + "Request %s timed out after %d seconds without " + "being sent. Freeing its blocks on the producer side.", + req_id, + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT, + ) + del self.reqs_need_send.reqs[req_id] + if expired_reqs: + finished_sending_reqs.update(expired_reqs) + + return finished_sending_reqs or None, finished_recving_reqs or None + + async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]): + req_ids, block_ids = map(list, zip(*req_blocks)) + metadata = MooncakeAgentMetadata( + remote_hostname=self.hostname, + remote_port=self.rpc_port, + request_ids=req_ids, + kv_caches_base_addr=self.kv_caches_base_addr, + block_ids=block_ids, + ) + + encoded_data = self._encoder.encode(metadata) + logger.debug( + "Size of encoded MooncakeAgentMetadata: %d bytes", len(encoded_data) + ) + logger.debug("Sending kv transfer request for %s on path: %s", req_ids, path) + + # Send query for the request. + sock: zmq.asyncio.Socket = make_zmq_socket( + self.async_zmq_ctx, path, zmq.REQ, bind=False, linger=0 + ) + sock.setsockopt(zmq.RCVTIMEO, 60000) + try: + await sock.send(encoded_data) + ret_msg = await sock.recv() + if ret_msg != TRANS_DONE: + logger.error( + "Error happens during tranfering kvcache for %s, see logs in prefiller.", # noqa: E501 + req_ids, + ) + return + except zmq.ContextTerminated: + logger.debug("ZMQ context terminated, exiting Mooncake receiver thread.") + except Exception as e: + logger.error("MooncakeAgentMetadata transfer failed for %s: %s", req_ids, e) + return + finally: + sock.close() + + async with self.finished_recving_reqs.lock: + self.finished_recving_reqs.set.update(req_ids) + + logger.debug("pulling kv_caches for %s finished", req_ids) + + def group_kv_pull(self, metadata: MooncakeConnectorMetadata): + kv_pulls = defaultdict(list) + for req_id, meta in metadata.reqs_to_recv.items(): + logger.debug( + "start_load_kv for request %s from remote engine. " + "Num local_block_ids: %s.", + req_id, + len(meta.local_block_ids), + ) + path = make_zmq_path( + "tcp", meta.remote_host, meta.remote_port + self.tp_rank + ) + kv_pulls[path].append((req_id, meta.local_block_ids)) + + return kv_pulls + + def start_load_kv(self, metadata: MooncakeConnectorMetadata): + if self.kv_role != "kv_producer": + kv_pulls = self.group_kv_pull(metadata) + for path, req_blocks in kv_pulls.items(): + asyncio.run_coroutine_threadsafe( + self.receive_kv(path, req_blocks), self.receiver_loop + ) + + if self.kv_role != "kv_consumer": + with self.reqs_need_send.lock: + for req_id, block_ids in metadata.reqs_to_send.items(): + if block_ids: + # Already gone through request_finished() + send_meta = self.reqs_need_send.reqs[req_id] + send_meta.local_block_ids = block_ids + send_meta.ready.set() + send_meta.expire_time = ( + time.perf_counter() + + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT + ) + else: + # From update_state_after_alloc(), + # but not reach request_finished() yet + self.reqs_need_send.reqs[req_id] = SendBlockMeta( + local_block_ids=[], ready=threading.Event() + ) + + +def group_concurrent_contiguous( + src_indices: list[int], dst_indices: list[int] +) -> tuple[list[list[int]], list[list[int]]]: + """Vectorised NumPy implementation.""" + if len(src_indices) == 0: + return [], [] + + brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1 + src_groups = np.split(src_indices, brk) + dst_groups = np.split(dst_indices, brk) + + src_groups = [g.tolist() for g in src_groups] + dst_groups = [g.tolist() for g in dst_groups] + + return src_groups, dst_groups + + +def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int: + # This logic is now centralized + return ( + envs.VLLM_MOONCAKE_BOOTSTRAP_PORT + + vllm_config.parallel_config.data_parallel_rank + * vllm_config.parallel_config.tensor_parallel_size + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 41e32bb73d40b..24b7599a4fe0c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -20,10 +20,10 @@ import torch import zmq from vllm import envs -from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata -from vllm.attention.backends.registry import AttentionBackendEnum +from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.selector import get_attn_backend from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology from vllm.distributed.kv_transfer.kv_connector.v1.base import ( CopyBlocksOp, KVConnectorBase_V1, @@ -668,128 +668,6 @@ class NixlConnectorScheduler: class NixlConnectorWorker: """Implementation of Worker side methods""" - @dataclass - class TpKVTopology: - """ - Helper class for tensor parallel and KV topology information for - mapping between local and remote TP workers. - """ - - tp_rank: int - remote_tp_size: dict[EngineId, int] - is_mla: bool - total_num_kv_heads: int - attn_backend: type[AttentionBackend] - engine_id: EngineId - remote_block_size: dict[EngineId, int] - - def __post_init__(self): - # Figure out whether the first dimension of the cache is K/V - # or num_blocks. This is used to register the memory regions correctly. - kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks=1, block_size=16, num_kv_heads=1, head_size=1 - ) - # Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D], - # we just mock num_blocks to 1 for the dimension check below. - self._is_kv_layout_blocks_first = ( - len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 - ) - - attn_backend = AttentionBackendEnum[self.attn_backend.get_name()] - self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS - - @property - def is_kv_layout_blocks_first(self) -> bool: - return self._is_kv_layout_blocks_first - - @property - def split_k_and_v(self) -> bool: - # Whether to register regions for K and V separately (when present). - return not ( - self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first - ) - - @property - def tp_size(self) -> int: - return self.remote_tp_size[self.engine_id] - - @property - def block_size(self) -> int: - return self.remote_block_size[self.engine_id] - - def tp_ratio( - self, - remote_tp_size: int, - ) -> int: - """ - Calculate the tensor parallel ratio between local and remote TP. - We can think of it as the number of local TP workers-per-remote TP - workers. Local workers will read from the same remote TP worker in - groups of size `tp_ratio`. - """ - assert self.tp_size % remote_tp_size == 0, ( - f"Local tensor parallel size {self.tp_size} is not divisible " - f"by remote tensor parallel size {remote_tp_size}." - ) - return self.tp_size // remote_tp_size - - def block_size_ratio( - self, - remote_block_size: int, - ) -> float: - """ - Calculate the block size ratio between local and remote TP. - """ - assert self.block_size % remote_block_size == 0, ( - f"Local block size {self.block_size} is not divisible " - f"by remote block size {remote_block_size} or vice versa." - ) - return self.block_size // remote_block_size - - def tp_ratio_from_engine_id( - self, - remote_engine_id: EngineId, - ) -> int: - remote_tp_size = self.remote_tp_size[remote_engine_id] - return self.tp_ratio(remote_tp_size) - - def block_size_ratio_from_engine_id( - self, - remote_engine_id: EngineId, - ) -> float: - remote_block_size = self.remote_block_size[remote_engine_id] - return self.block_size_ratio(remote_block_size) - - def is_kv_replicated(self, engine_id: EngineId) -> bool: - """ - Whether the KV cache is replicated across TP workers due to the - number of TP workers being greater than the number of KV heads. - """ - tp_size = self.remote_tp_size[engine_id] - return tp_size // self.total_num_kv_heads >= 1 - - def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: - # MLA is always replicated as the hidden dim can't be split. - return self.is_mla or self.is_kv_replicated(remote_engine_id) - - def get_target_remote_rank( - self, - remote_tp_size: int, - ) -> int: - """ - Get the remote TP rank (on P) that the current local TP rank - (on D) will read from. - """ - tp_ratio = self.tp_ratio(remote_tp_size) - return self.tp_rank // tp_ratio - - def get_target_remote_rank_from_engine_id( - self, - remote_engine_id: EngineId, - ) -> int: - remote_tp_size = self.remote_tp_size[remote_engine_id] - return self.get_target_remote_rank(remote_tp_size) - def __init__(self, vllm_config: VllmConfig, engine_id: str): if NixlWrapper is None: logger.error("NIXL is not available") @@ -958,7 +836,7 @@ class NixlConnectorWorker: self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int) self.xfer_stats = NixlKVConnectorStats() - self.kv_topo = self.TpKVTopology( + self.kv_topo = TpKVTopology( tp_rank=self.tp_rank, engine_id=self.engine_id, remote_tp_size=self._tp_size, # shared state diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index c82a77c216af2..f910f10407d44 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1169,17 +1169,13 @@ def init_distributed_environment( from vllm.config import get_current_vllm_config config = get_current_vllm_config() - if config is not None and config.parallel_config.nnodes > 1: - parallel_config = config.parallel_config - ip = parallel_config.master_addr - rank = parallel_config.data_parallel_rank * world_size + rank - world_size = parallel_config.world_size_across_dp - port = parallel_config.master_port - distributed_init_method = get_distributed_init_method(ip, port) - elif ( + if ( config is not None - and config.parallel_config.data_parallel_size > 1 and config.parallel_config.distributed_executor_backend != "external_launcher" + and ( + config.parallel_config.nnodes > 1 + or config.parallel_config.data_parallel_size > 1 + ) ): parallel_config = config.parallel_config # adjust to take into account data parallelism @@ -1187,15 +1183,22 @@ def init_distributed_environment( rank = parallel_config.data_parallel_rank * world_size + rank # adjust the world size to take into account data parallelism world_size = parallel_config.world_size_across_dp - ip = parallel_config.data_parallel_master_ip - port = parallel_config.get_next_dp_init_port() - distributed_init_method = get_distributed_init_method(ip, port) - logger.debug( - "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP", - world_size, - rank, - distributed_init_method, - ) + + # Use appropriate IP and port based on configuration + if parallel_config.nnodes > 1: + ip = parallel_config.master_addr + port = parallel_config.master_port + distributed_init_method = get_distributed_init_method(ip, port) + else: + ip = parallel_config.data_parallel_master_ip + port = parallel_config.get_next_dp_init_port() + distributed_init_method = get_distributed_init_method(ip, port) + logger.debug( + "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP", + world_size, + rank, + distributed_init_method, + ) if not torch.distributed.is_initialized(): logger.info( "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s", diff --git a/vllm/entrypoints/anthropic/serving_messages.py b/vllm/entrypoints/anthropic/serving_messages.py index 340dabf0e7117..e7ea3bb59ca70 100644 --- a/vllm/entrypoints/anthropic/serving_messages.py +++ b/vllm/entrypoints/anthropic/serving_messages.py @@ -183,7 +183,9 @@ class AnthropicServingMessages(OpenAIServingChat): if anthropic_request.stream: req.stream = anthropic_request.stream - req.stream_options = StreamOptions.validate({"include_usage": True}) + req.stream_options = StreamOptions.validate( + {"include_usage": True, "continuous_usage_stats": True} + ) if anthropic_request.tool_choice is None: req.tool_choice = None @@ -323,6 +325,12 @@ class AnthropicServingMessages(OpenAIServingChat): content=[], model=origin_chunk.model, ), + usage=AnthropicUsage( + input_tokens=origin_chunk.usage.prompt_tokens + if origin_chunk.usage + else 0, + output_tokens=0, + ), ) first_item = False data = chunk.model_dump_json(exclude_unset=True) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 2dd5b9c8f8aa0..077fe681bc5b8 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -536,7 +536,7 @@ def resolve_hf_chat_template( def _resolve_chat_template_content_format( chat_template: str | None, tools: list[dict[str, Any]] | None, - tokenizer: TokenizerLike, + tokenizer: TokenizerLike | None, *, model_config: ModelConfig, ) -> _ChatTemplateContentFormat: @@ -593,7 +593,7 @@ def resolve_chat_template_content_format( chat_template: str | None, tools: list[dict[str, Any]] | None, given_format: ChatTemplateContentFormatOption, - tokenizer: TokenizerLike, + tokenizer: TokenizerLike | None, *, model_config: ModelConfig, ) -> _ChatTemplateContentFormat: @@ -627,11 +627,10 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): maximum per prompt. """ - def __init__(self, model_config: ModelConfig, tokenizer: TokenizerLike): + def __init__(self, model_config: ModelConfig): super().__init__() self._model_config = model_config - self._tokenizer = tokenizer self._items_by_modality = defaultdict[str, list[_T | None]](list) self._uuids_by_modality = defaultdict[str, list[str | None]](list) @@ -1139,11 +1138,19 @@ def validate_chat_template(chat_template: Path | str | None): not any(c in chat_template for c in JINJA_CHARS) and not Path(chat_template).exists() ): - raise ValueError( - f"The supplied chat template string ({chat_template}) " - f"appears path-like, but doesn't exist!" + # Try to find the template in the built-in templates directory + from vllm.transformers_utils.chat_templates.registry import ( + CHAT_TEMPLATES_DIR, ) + builtin_template_path = CHAT_TEMPLATES_DIR / chat_template + if not builtin_template_path.exists(): + raise ValueError( + f"The supplied chat template string ({chat_template}) " + f"appears path-like, but doesn't exist! " + f"Tried: {chat_template} and {builtin_template_path}" + ) + else: raise TypeError(f"{type(chat_template)} is not a valid chat template type") @@ -1173,12 +1180,23 @@ def _load_chat_template( JINJA_CHARS = "{}\n" if not any(c in chat_template for c in JINJA_CHARS): - msg = ( - f"The supplied chat template ({chat_template}) " - f"looks like a file path, but it failed to be " - f"opened. Reason: {e}" + # Try to load from the built-in templates directory + from vllm.transformers_utils.chat_templates.registry import ( + CHAT_TEMPLATES_DIR, ) - raise ValueError(msg) from e + + builtin_template_path = CHAT_TEMPLATES_DIR / chat_template + try: + with open(builtin_template_path) as f: + return f.read() + except OSError: + msg = ( + f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be opened. " + f"Tried: {chat_template} and {builtin_template_path}. " + f"Reason: {e}" + ) + raise ValueError(msg) from e # If opening a file fails, set chat template to be args to # ensure we decode so our escape are interpreted correctly @@ -1593,7 +1611,6 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None: def parse_chat_messages( messages: list[ChatCompletionMessageParam], model_config: ModelConfig, - tokenizer: TokenizerLike, content_format: _ChatTemplateContentFormat, ) -> tuple[ list[ConversationMessage], @@ -1601,7 +1618,7 @@ def parse_chat_messages( MultiModalUUIDDict | None, ]: conversation: list[ConversationMessage] = [] - mm_tracker = MultiModalItemTracker(model_config, tokenizer) + mm_tracker = MultiModalItemTracker(model_config) for msg in messages: sub_messages = _parse_chat_message_content( @@ -1625,7 +1642,6 @@ def parse_chat_messages( def parse_chat_messages_futures( messages: list[ChatCompletionMessageParam], model_config: ModelConfig, - tokenizer: TokenizerLike, content_format: _ChatTemplateContentFormat, ) -> tuple[ list[ConversationMessage], @@ -1633,7 +1649,7 @@ def parse_chat_messages_futures( MultiModalUUIDDict | None, ]: conversation: list[ConversationMessage] = [] - mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) + mm_tracker = AsyncMultiModalItemTracker(model_config) for msg in messages: sub_messages = _parse_chat_message_content( diff --git a/vllm/entrypoints/harmony_utils.py b/vllm/entrypoints/harmony_utils.py index 47a252348c102..bb932e39e0472 100644 --- a/vllm/entrypoints/harmony_utils.py +++ b/vllm/entrypoints/harmony_utils.py @@ -328,6 +328,105 @@ def render_for_completion(messages: list[Message]) -> list[int]: return token_ids +def _parse_browser_tool_call(message: Message, recipient: str) -> ResponseOutputItem: + """Parse browser tool calls (search, open, find) into web search items.""" + if len(message.content) != 1: + raise ValueError("Invalid number of contents in browser message") + content = message.content[0] + + # Parse JSON args (with retry detection) + try: + browser_call = json.loads(content.text) + except json.JSONDecodeError: + json_retry_output_message = ( + f"Invalid JSON args, caught and retried: {content.text}" + ) + browser_call = { + "query": json_retry_output_message, + "url": json_retry_output_message, + "pattern": json_retry_output_message, + } + + # Create appropriate action based on recipient + if recipient == "browser.search": + action = ActionSearch( + query=f"cursor:{browser_call.get('query', '')}", type="search" + ) + elif recipient == "browser.open": + action = ActionOpenPage( + url=f"cursor:{browser_call.get('url', '')}", type="open_page" + ) + elif recipient == "browser.find": + action = ActionFind( + pattern=browser_call.get("pattern", ""), + url=f"cursor:{browser_call.get('url', '')}", + type="find", + ) + else: + raise ValueError(f"Unknown browser action: {recipient}") + + return ResponseFunctionWebSearch( + id=f"ws_{random_uuid()}", + action=action, + status="completed", + type="web_search_call", + ) + + +def _parse_function_call(message: Message, recipient: str) -> list[ResponseOutputItem]: + """Parse function calls into function tool call items.""" + function_name = recipient.split(".")[-1] + output_items = [] + for content in message.content: + random_id = random_uuid() + response_item = ResponseFunctionToolCall( + arguments=content.text, + call_id=f"call_{random_id}", + type="function_call", + name=function_name, + id=f"fc_{random_id}", + ) + output_items.append(response_item) + return output_items + + +def _parse_reasoning_content(message: Message) -> list[ResponseOutputItem]: + """Parse reasoning/analysis content into reasoning items.""" + output_items = [] + for content in message.content: + reasoning_item = ResponseReasoningItem( + id=f"rs_{random_uuid()}", + summary=[], + type="reasoning", + content=[ + ResponseReasoningTextContent(text=content.text, type="reasoning_text") + ], + status=None, + ) + output_items.append(reasoning_item) + return output_items + + +def _parse_final_message(message: Message) -> ResponseOutputItem: + """Parse final channel messages into output message items.""" + contents = [] + for content in message.content: + output_text = ResponseOutputText( + text=content.text, + annotations=[], # TODO + type="output_text", + logprobs=None, # TODO + ) + contents.append(output_text) + return ResponseOutputMessage( + id=f"msg_{random_uuid()}", + content=contents, + role=message.author.role, + status="completed", + type="message", + ) + + def parse_output_message(message: Message) -> list[ResponseOutputItem]: """ Parse a Harmony message into a list of output response items. @@ -340,119 +439,38 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: output_items: list[ResponseOutputItem] = [] recipient = message.recipient + + # Browser tool calls if recipient is not None and recipient.startswith("browser."): - if len(message.content) != 1: - raise ValueError("Invalid number of contents in browser message") - content = message.content[0] - # We do not need to check the VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY - # env variable since if it is not set, we are certain the json is valid - # The use of Actions for web search will be removed entirely in - # the future, so this is only necessary temporarily - try: - browser_call = json.loads(content.text) - except json.JSONDecodeError: - # If the content is not valid JSON, then it was - # caught and retried by vLLM, which means we - # need to make note of that so the user is aware - json_retry_output_message = ( - f"Invalid JSON args, caught and retried: {content.text}" - ) - browser_call = { - "query": json_retry_output_message, - "url": json_retry_output_message, - "pattern": json_retry_output_message, - } - # TODO: translate to url properly! - if recipient == "browser.search": - action = ActionSearch( - query=f"cursor:{browser_call.get('query', '')}", type="search" - ) - elif recipient == "browser.open": - action = ActionOpenPage( - url=f"cursor:{browser_call.get('url', '')}", type="open_page" - ) - elif recipient == "browser.find": - action = ActionFind( - pattern=browser_call["pattern"], - url=f"cursor:{browser_call.get('url', '')}", - type="find", - ) - else: - raise ValueError(f"Unknown browser action: {recipient}") - web_search_item = ResponseFunctionWebSearch( - id=f"ws_{random_uuid()}", - action=action, - status="completed", - type="web_search_call", - ) - output_items.append(web_search_item) + output_items.append(_parse_browser_tool_call(message, recipient)) + + # Analysis channel (reasoning/chain-of-thought) elif message.channel == "analysis": - for content in message.content: - reasoning_item = ResponseReasoningItem( - id=f"rs_{random_uuid()}", - summary=[], - type="reasoning", - content=[ - ResponseReasoningTextContent( - text=content.text, type="reasoning_text" - ) - ], - status=None, - ) - output_items.append(reasoning_item) + output_items.extend(_parse_reasoning_content(message)) + + # Commentary channel elif message.channel == "commentary": + # Function calls if recipient is not None and recipient.startswith("functions."): - function_name = recipient.split(".")[-1] - for content in message.content: - random_id = random_uuid() - response_item = ResponseFunctionToolCall( - arguments=content.text, - call_id=f"call_{random_id}", - type="function_call", - name=function_name, - id=f"fc_{random_id}", - ) - output_items.append(response_item) + output_items.extend(_parse_function_call(message, recipient)) + + # Built-in tools on commentary channel are treated as reasoning for now elif recipient is not None and ( recipient.startswith("python") or recipient.startswith("browser") or recipient.startswith("container") ): - for content in message.content: - reasoning_item = ResponseReasoningItem( - id=f"rs_{random_uuid()}", - summary=[], - type="reasoning", - content=[ - ResponseReasoningTextContent( - text=content.text, type="reasoning_text" - ) - ], - status=None, - ) - output_items.append(reasoning_item) + output_items.extend(_parse_reasoning_content(message)) else: raise ValueError(f"Unknown recipient: {recipient}") + + # Final output message elif message.channel == "final": - contents = [] - for content in message.content: - output_text = ResponseOutputText( - text=content.text, - annotations=[], # TODO - type="output_text", - logprobs=None, # TODO - ) - contents.append(output_text) - text_item = ResponseOutputMessage( - id=f"msg_{random_uuid()}", - content=contents, - role=message.author.role, - status="completed", - type="message", - ) - output_items.append(text_item) + output_items.append(_parse_final_message(message)) + else: raise ValueError(f"Unknown channel: {message.channel}") + return output_items diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c121fa71f0196..481a47a97f7d4 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -834,7 +834,6 @@ class LLM: conversation, mm_data, mm_uuids = parse_chat_messages( msgs, model_config, - tokenizer, content_format=resolved_content_format, ) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 67291f45a9251..bfa98f29a064b 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -105,7 +105,7 @@ from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.sampling_params import BeamSearchParams, SamplingParams -from vllm.tokenizers import MistralTokenizer, TokenizerLike +from vllm.tokenizers import DeepseekV32Tokenizer, MistralTokenizer, TokenizerLike from vllm.tracing import ( contains_trace_headers, extract_trace_headers, @@ -1088,11 +1088,6 @@ class OpenAIServing: Sequence[RequestPrompt], list[EngineTokensPrompt], ]: - if tokenizer is None: - raise ValueError( - "Unable to get tokenizer because `skip_tokenizer_init=True`" - ) - model_config = self.model_config resolved_content_format = resolve_chat_template_content_format( @@ -1105,7 +1100,6 @@ class OpenAIServing: conversation, mm_data_future, mm_uuids = parse_chat_messages_futures( messages, model_config, - tokenizer, content_format=resolved_content_format, ) @@ -1128,6 +1122,13 @@ class OpenAIServing: messages=messages, **_chat_template_kwargs, ) + elif isinstance(tokenizer, DeepseekV32Tokenizer): + request_prompt = tokenizer.apply_chat_template( + conversation=conversation, + messages=messages, + model_config=model_config, + **_chat_template_kwargs, + ) else: request_prompt = apply_hf_chat_template( tokenizer=tokenizer, diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 89e439dd53f5f..ed43ea7eec82f 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -30,6 +30,10 @@ _TOOL_PARSERS_TO_REGISTER = { "deepseekv31_tool_parser", "DeepSeekV31ToolParser", ), + "deepseek_v32": ( + "deepseekv32_tool_parser", + "DeepSeekV32ToolParser", + ), "ernie45": ( "ernie45_tool_parser", "Ernie45ToolParser", diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv32_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv32_tool_parser.py new file mode 100644 index 0000000000000..4973deb7cefa8 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/deepseekv32_tool_parser.py @@ -0,0 +1,591 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import uuid +from collections.abc import Sequence +from typing import Any + +import regex as re + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, +) +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike + +logger = init_logger(__name__) + + +class DeepSeekV32ToolParser(ToolParser): + """ + example tool call content: + <|DSML|function_calls> + <|DSML|invoke name="get_weather"> + <|DSML|parameter name="location" string="true">杭州 + <|DSML|parameter name="date" string="true">2024-01-16 + + <|DSML|invoke name="get_weather"> + <|DSML|parameter name="location" string="true">北京 + <|DSML|parameter name="date" string="true">2024-01-16 + + + """ + + def __init__(self, tokenizer: TokenizerLike): + super().__init__(tokenizer) + + self.prev_tool_call_arr: list[dict] = [] + + # Sentinel tokens + self.dsml_token: str = "|DSML|" + self.dsml_start_check: str = "<" + self.dsml_token + self.tool_call_start_token: str = "<|DSML|function_calls>" + self.tool_call_end_token: str = "" + self.invoke_start_prefix: str = "<|DSML|invoke name=" + self.invoke_end_token: str = "" + self.parameter_prefix: str = "<|DSML|parameter name=" + self.parameter_end_token: str = "" + + # Streaming state variables + self.current_tool_name_sent: bool = False + # Override base class type - we use string IDs for tool calls + self.current_tool_id: str | None = None # type: ignore + self.streamed_args_for_tool: list[str] = [] + self.is_tool_call_started: bool = False + self.failed_count: int = 0 + + # Initialize streaming state variables + self.current_tool_index: int = 0 + self.invoke_index: int = 0 + self.header_sent: bool = False + self.current_function_name: str | None = None + self.current_param_name: str | None = None + self.current_param_value: str = "" + self.param_count: int = 0 + self.in_param: bool = False + self.in_function: bool = False + self.json_started: bool = False + self.json_closed: bool = False + self.accumulated_params: dict = {} + self.streaming_request: ChatCompletionRequest | None = None + + # Enhanced streaming state - reset for each new message + self._reset_streaming_state() + + # Regex patterns for complete parsing + self.tool_call_complete_regex = re.compile( + r"<|DSML|function_calls>(.*?)", re.DOTALL + ) + self.invoke_complete_regex = re.compile( + r'<|DSML|invoke\s+name="([^"]+)"\s*>(.*?)', re.DOTALL + ) + self.parameter_complete_regex = re.compile( + r'<|DSML|parameter\s+name="([^"]+)"\s+string="(?:true|false)"\s*>(.*?)', + re.DOTALL, + ) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction." + ) + + logger.debug( + "vLLM Successfully import tool parser %s !", self.__class__.__name__ + ) + + def _generate_tool_call_id(self) -> str: + """Generate a unique tool call ID.""" + return f"call_{uuid.uuid4().hex[:24]}" + + def _reset_streaming_state(self): + """Reset all streaming state.""" + self.current_tool_index = 0 + self.invoke_index = 0 + self.is_tool_call_started = False + self.header_sent = False + self.current_tool_id = None + self.current_function_name = None + self.current_param_name = None + self.current_param_value = "" + self.param_count = 0 + self.in_param = False + self.in_function = False + self.json_started = False + self.json_closed = False + # Store accumulated parameters for type conversion + self.accumulated_params = {} + self.streaming_request = None + # Clear previous tool call history to avoid state pollution + self.prev_tool_call_arr.clear() + + def _parse_invoke_params(self, invoke_str: str) -> dict | None: + param_dict = dict() + for param_name, param_val in self.parameter_complete_regex.findall(invoke_str): + param_dict[param_name] = param_val + return param_dict + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + """Extract tool calls from complete model output (non-streaming).""" + # Quick check + if self.tool_call_start_token not in model_output: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + try: + tool_calls = [] + + # Find all complete tool_call blocks + for tool_call_match in self.tool_call_complete_regex.findall(model_output): + # Find all invokes within this tool_call + for invoke_name, invoke_content in self.invoke_complete_regex.findall( + tool_call_match + ): + param_dict = self._parse_invoke_params(invoke_content) + tool_calls.append( + ToolCall( + type="function", + function=FunctionCall( + name=invoke_name, + arguments=json.dumps(param_dict, ensure_ascii=False), + ), + ) + ) + + if not tool_calls: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + # Extract content before first tool call + first_tool_idx = model_output.find(self.tool_call_start_token) + content = model_output[:first_tool_idx] if first_tool_idx > 0 else None + + return ExtractedToolCallInformation( + tools_called=True, tool_calls=tool_calls, content=content + ) + + except Exception: + logger.exception("Error extracting tool calls") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + def _extract_name(self, name_str: str) -> str: + """Extract name from quoted string.""" + name_str = name_str.strip() + if ( + name_str.startswith('"') + and name_str.endswith('"') + or name_str.startswith("'") + and name_str.endswith("'") + ): + return name_str[1:-1] + return name_str + + def _extract_param_name(self, input_str: str) -> str: + """Extract param name""" + start = input_str.find('"') + 1 + end = input_str.find('"', start) + return input_str[start:end] if start > 0 and end > start else input_str + + def _convert_param_value(self, value: str, param_type: str) -> Any: + """Convert parameter value to the correct type.""" + if value.lower() == "null": + return None + + param_type = param_type.lower() + if param_type in ["string", "str", "text"]: + return value + elif param_type in ["integer", "int"]: + try: + return int(value) + except (ValueError, TypeError): + return value + elif param_type in ["number", "float"]: + try: + val = float(value) + return val if val != int(val) else int(val) + except (ValueError, TypeError): + return value + elif param_type in ["boolean", "bool"]: + return value.lower() in ["true", "1"] + elif param_type in ["object", "array"]: + try: + return json.loads(value) + except json.JSONDecodeError: + return value + else: + # Try JSON parse first, fallback to string + try: + return json.loads(value) + except json.JSONDecodeError: + return value + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], # pylint: disable=unused-argument + current_token_ids: Sequence[int], # pylint: disable=unused-argument + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + """Extract tool calls from streaming model output.""" + + # Store request for type conversion + if not previous_text: + self._reset_streaming_state() + self.streaming_request = request + + # If no delta text, return None unless it's an EOS token after tools + if not delta_text: + # Check if this is an EOS token after all tool calls are complete + if delta_token_ids: + # Count complete tool calls + complete_calls = len( + self.tool_call_complete_regex.findall(current_text) + ) + + # If we have completed tool calls and populated prev_tool_call_arr + if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: + # Check if all tool calls are closed + open_calls = current_text.count( + self.tool_call_start_token + ) - current_text.count(self.tool_call_end_token) + if open_calls == 0: + # Return empty delta for finish_reason processing + return DeltaMessage(content="") + elif not self.is_tool_call_started and current_text: + # This is a regular content response that's now complete + return DeltaMessage(content="") + return None + + # Check if we need to advance to next tool + if self.json_closed and not self.in_function: + # Check if this tool call has ended + invoke_ends = current_text.count(self.invoke_end_token) + if invoke_ends > self.current_tool_index: + # This tool has ended, advance to next + self.current_tool_index += 1 + self.header_sent = False + self.param_count = 0 + self.json_started = False + self.json_closed = False + self.in_function = False # Now we can safely set this to False + self.accumulated_params = {} + # Continue processing next tool + return None + + # Handle normal content before tool calls + if not self.is_tool_call_started: + # Check if tool call is starting + if self.dsml_token in current_text: + self.is_tool_call_started = True + # Return any content before the tool call + if self.dsml_start_check in delta_text: + content_before = delta_text[ + : delta_text.index(self.dsml_start_check) + ] + if content_before: + return DeltaMessage(content=content_before) + return None + else: + # Check if we're between tool calls - skip whitespace + if ( + current_text.rstrip().endswith(self.tool_call_end_token) + and delta_text.strip() == "" + ): + # We just ended a tool call, skip whitespace + return None + # Normal content, no tool call + if delta_text.endswith("<"): + return DeltaMessage(content=delta_text[:-1]) + if previous_text and previous_text.endswith("<"): + return DeltaMessage(content="<" + delta_text) + return DeltaMessage(content=delta_text) + + # Check if we're between tool calls (waiting for next one) + invoke_starts_count = current_text.count(self.invoke_start_prefix) + if self.current_tool_index >= invoke_starts_count: + # We're past all tool calls, shouldn't be here + return None + + # Find the current tool call portion + invoke_start_positions: list[int] = [] + idx = 0 + while True: + idx = current_text.find(self.invoke_start_prefix, idx) + if idx == -1: + break + invoke_start_positions.append(idx) + idx += len(self.invoke_start_prefix) + + if self.current_tool_index >= len(invoke_start_positions): + # No more tool calls to process yet + return None + + invoke_start_idx = invoke_start_positions[self.current_tool_index] + # Find where this tool call ends (or current position if not ended yet) + invoke_end_idx = current_text.find(self.invoke_end_token, invoke_start_idx) + if invoke_end_idx == -1: + tool_text = current_text[invoke_start_idx:] + else: + tool_text = current_text[ + invoke_start_idx : invoke_end_idx + len(self.invoke_end_token) + ] + + # Looking for function header + if not self.header_sent: + if self.invoke_start_prefix in tool_text: + func_start = tool_text.find(self.invoke_start_prefix) + len( + self.invoke_start_prefix + ) + # Find the end quote for the function name + func_end = tool_text.find(">", func_start) + + if func_end != -1: + # Found complete function name + function_name_raw = tool_text[func_start:func_end] + self.current_function_name = self._extract_name(function_name_raw) + self.current_tool_id = self._generate_tool_call_id() + self.header_sent = True + self.in_function = True + + # Add to prev_tool_call_arr immediately when we detect a tool call + # Each tool call should be recorded regardless of function name + # Ensure we don't add the same tool call index multiple times + if len(self.prev_tool_call_arr) <= self.current_tool_index: + self.prev_tool_call_arr.append( + { + "name": self.current_function_name, + "arguments": "{}", # Placeholder, will be updated later + } + ) + + # Send header with function info + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + id=self.current_tool_id, + function=DeltaFunctionCall( + name=self.current_function_name, arguments="" + ), + type="function", + ) + ] + ) + return None + + # We've sent header, now handle function body + if self.in_function: + # Send opening brace if not sent yet + if self.in_function and not self.json_started: + self.json_started = True + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="{"), + ) + ] + ) + + # Make sure json_started is set if we're processing parameters + if not self.json_started: + self.json_started = True + + # Check for function end in accumulated text + if not self.json_closed and self.invoke_end_token in tool_text: + # Count total parameters in the tool text + total_param_count = tool_text.count(self.parameter_prefix) + + # Only close JSON if all parameters have been processed + if self.param_count >= total_param_count: + # Close JSON + self.json_closed = True + + # Extract complete tool call + # Find the invoke content + invoke_start = tool_text.find(self.invoke_start_prefix) + len( + self.invoke_start_prefix + ) + invoke_content_end = tool_text.find( + self.invoke_end_token, invoke_start + ) + if invoke_content_end != -1: + invoke_content = tool_text[invoke_start:invoke_content_end] + # Parse to get the complete arguments + try: + invoke_params = self._parse_invoke_params(invoke_content) + if invoke_params and self.current_tool_index < len( + self.prev_tool_call_arr + ): + # Update existing entry in prev_tool_call_arr + self.prev_tool_call_arr[self.current_tool_index][ + "arguments" + ] = json.dumps(invoke_params, ensure_ascii=False) + except Exception: + pass # Ignore parsing errors during streaming + + result = DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="}"), + ) + ] + ) + + # Reset state for next tool + self.json_closed = True + self.in_function = False + self.accumulated_params = {} + + logger.debug("[M2_STREAMING] Tool call completed") + + return result + else: + # Don't close JSON yet, continue processing parameters + return None + + # Look for parameters + # Find all parameter starts + param_starts = [] + idx = 0 + while True: + idx = tool_text.find(self.parameter_prefix, idx) + if idx == -1: + break + param_starts.append(idx) + idx += len(self.parameter_prefix) + + # Check if we should start a new parameter + if ( + not self.in_param + and self.param_count < len(param_starts) + and len(param_starts) > self.param_count + ): + # Process the next parameter + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] + + if ">" in remaining: + # We have the complete parameter name + name_end = remaining.find(">") + param_name_raw = remaining[:name_end] + self.current_param_name = self._extract_param_name(param_name_raw) + + # Find the parameter value + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + if value_text.startswith("\n"): + value_text = value_text[1:] + + # Find where this parameter ends + param_end_idx = value_text.find(self.parameter_end_token) + if param_end_idx == -1: + # No closing tag, look for next parameter or function end + next_param_idx = value_text.find(self.parameter_prefix) + func_end_idx = value_text.find(self.invoke_end_token) + + if next_param_idx != -1 and ( + func_end_idx == -1 or next_param_idx < func_end_idx + ): + param_end_idx = next_param_idx + elif func_end_idx != -1: + param_end_idx = func_end_idx + else: + # Neither found, check if tool call is complete + if self.invoke_end_token in tool_text: + # Tool call and parameter is complete + param_end_idx = len(value_text) + else: + # Still streaming, wait for more content + return None + + if param_end_idx != -1: + # Complete parameter found + param_value = value_text[:param_end_idx] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + # Store raw value for later processing + self.accumulated_params[self.current_param_name] = param_value + + # Get parameter configuration for type conversion + param_config = {} + if self.streaming_request and self.streaming_request.tools: + for tool in self.streaming_request.tools: + if ( + hasattr(tool, "function") + and tool.function.name == self.current_function_name + and hasattr(tool.function, "parameters") + ): + params = tool.function.parameters + if ( + isinstance(params, dict) + and "properties" in params + ): + param_config = params["properties"] + break + + # Get parameter type + param_type = "string" + if ( + self.current_param_name in param_config + and isinstance(param_config[self.current_param_name], dict) + and "type" in param_config[self.current_param_name] + ): + param_type = param_config[self.current_param_name]["type"] + + # Convert param value to appropriate type + converted_value = self._convert_param_value( + param_value, param_type + ) + + # Build JSON fragment based on the converted type + # Use json.dumps to properly serialize the value + serialized_value = json.dumps( + converted_value, ensure_ascii=False + ) + + if self.param_count == 0: + json_fragment = ( + f'"{self.current_param_name}": {serialized_value}' + ) + else: + json_fragment = ( + f', "{self.current_param_name}": {serialized_value}' + ) + + self.param_count += 1 + + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments=json_fragment), + ) + ] + ) + + return None diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index b89db60545abd..aa5089ffe84d7 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -3,12 +3,12 @@ import json from collections.abc import Sequence +from enum import Enum, auto from random import choices from string import ascii_letters, digits -import partial_json_parser +import ijson import regex as re -from partial_json_parser.core.options import Allow from pydantic import Field from vllm.entrypoints.openai.protocol import ( @@ -23,7 +23,6 @@ from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, ) -from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.logger import init_logger from vllm.tokenizers import MistralTokenizer, TokenizerLike @@ -32,6 +31,22 @@ logger = init_logger(__name__) ALPHANUMERIC = ascii_letters + digits +class StreamingState(Enum): + """Enum for tracking the current streaming parsing state.""" + + WAITING_FOR_TOOL_START = auto() + WAITING_FOR_TOOL_KEY = ( + auto() + ) # waiting for the "name" or "arguments" key to be complete + PARSING_NAME = auto() + PARSING_NAME_COMPLETED = auto() + WAITING_FOR_ARGUMENTS_START = auto() + PARSING_ARGUMENTS = auto() + PARSING_ARGUMENTS_COMPLETED = auto() + TOOL_COMPLETE = auto() + ALL_TOOLS_COMPLETE = auto() + + class MistralToolCall(ToolCall): id: str = Field(default_factory=lambda: MistralToolCall.generate_random_id()) @@ -46,8 +61,8 @@ class MistralToolCall(ToolCall): return id.isalnum() and len(id) == 9 -def _is_fn_name_regex_support(model_tokenizer: TokenizerLike) -> bool: - return ( +def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool: + return not ( isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11 ) @@ -69,16 +84,22 @@ class MistralToolParser(ToolParser): # initialize properties used for state when parsing tool calls in # streaming mode - self.prev_tool_call_arr: list[dict] = [] self.current_tool_id: int = -1 - self.current_tool_name_sent: bool = False - self.streamed_args_for_tool: list[ - str - ] = [] # map what has been streamed for each tool so far to a list + self.streaming_state: StreamingState = StreamingState.WAITING_FOR_TOOL_START + + # For streaming pre v11 tokenizer tool calls + self.current_tool_name: str | None = None + self.current_tool_mistral_id: str | None = None + self.starting_new_tool = False + if _is_pre_v11_tokeniser(self.model_tokenizer): + self.parse_coro = ijson.parse_coro( + self.update_stream_state_pre_v11_tokenizer() + ) + self.bot_token = "[TOOL_CALLS]" self.bot_token_id = self.vocab.get(self.bot_token) self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) - if _is_fn_name_regex_support(self.model_tokenizer): + if not _is_pre_v11_tokeniser(self.model_tokenizer): self.fn_name_regex = re.compile( r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL ) @@ -131,18 +152,19 @@ class MistralToolParser(ToolParser): # jsons is difficult try: if self.fn_name_regex: - matches = self.fn_name_regex.findall(tool_content) - function_call_arr = [] - for match in matches: - fn_name = match[0] - args = match[1] + for single_tool_content in model_output.split(self.bot_token): + matches = self.fn_name_regex.findall(single_tool_content) - # fn_name is encoded outside serialized json dump - # only arguments are serialized - function_call_arr.append( - {"name": fn_name, "arguments": json.loads(args)} - ) + for match in matches: + fn_name = match[0] + args = match[1] + + # fn_name is encoded outside serialized json dump + # only arguments are serialized + function_call_arr.append( + {"name": fn_name, "arguments": json.loads(args)} + ) else: function_call_arr = json.loads(tool_content) except json.JSONDecodeError: @@ -193,198 +215,372 @@ class MistralToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> DeltaMessage | None: - # if the tool call token is not in the tokens generated so far, append - # output to contents since it's not a tool - if self.bot_token not in current_text: + if self.bot_token_id not in current_token_ids: + # if the tool call token is not in the tokens generated so far, + # append output to contents since it's not a tool return DeltaMessage(content=delta_text) - # if the tool call token ID IS in the tokens generated so far, that + # if the tool call token IS in the tokens generated so far, that # means we're parsing as tool calls now - - # handle if we detected the BOT token which means the start of tool - # calling - if self.bot_token_id in delta_token_ids and len(delta_token_ids) == 1: - # if it's the only token, return None, so we don't send a chat - # completion any don't send a control token - return None - - # bit mask flags for partial JSON parsing. If the name hasn't been - # sent yet, don't allow sending - # an incomplete string since OpenAI only ever (as far as I have - # seen) allows sending the entire tool/ function name at once. - flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR try: - # replace BOT token with empty string, and convert single quotes - # to double to allow parsing as JSON since mistral uses single - # quotes instead of double for tool calls - parsable_arr = current_text.split(self.bot_token)[-1] - - # tool calls are generated in an array, so do partial JSON - # parsing on the entire array - try: - tool_call_arr: list[dict] = partial_json_parser.loads( - parsable_arr, flags + if _is_pre_v11_tokeniser(self.model_tokenizer): + return self._extract_tool_calls_streaming_pre_v11_tokenizer( + delta_text=delta_text, + delta_token_ids=delta_token_ids, ) - except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug("not enough tokens to parse into JSON yet") - return None - - # select as the current tool call the one we're on the state at - - current_tool_call: dict = ( - tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} - ) - - # case -- if no tokens have been streamed for the tool, e.g. - # only the array brackets, stream nothing - if len(tool_call_arr) == 0: - return None - - # case: we are starting a new tool in the array - # -> array has > 0 length AND length has moved past cursor - elif ( - len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1 - ): - # if we're moving on to a new call, first make sure we - # haven't missed anything in the previous one that was - # auto-generated due to JSON completions, but wasn't - # streamed to the client yet. - if self.current_tool_id >= 0: - diff: str | None = current_tool_call.get("arguments") - - if diff: - diff = json.dumps(diff, ensure_ascii=False).replace( - self.streamed_args_for_tool[self.current_tool_id], "" - ) - delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff - ).model_dump(exclude_none=True), - ) - ] - ) - self.streamed_args_for_tool[self.current_tool_id] += diff - else: - delta = None - else: - delta = None - # re-set stuff pertaining to progress in the current tool - self.current_tool_id = len(tool_call_arr) - 1 - self.current_tool_name_sent = False - self.streamed_args_for_tool.append("") - logger.debug("starting on new tool %d", self.current_tool_id) - return delta - - # case: update an existing tool - this is handled below - - # if the current tool name hasn't been sent, send if available - # - otherwise send nothing - if not self.current_tool_name_sent: - function_name = current_tool_call.get("name") - if function_name: - delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - type="function", - id=MistralToolCall.generate_random_id(), - function=DeltaFunctionCall( - name=function_name - ).model_dump(exclude_none=True), - ) - ] - ) - self.current_tool_name_sent = True - else: - delta = None - - # now we know we're on the same tool call and we're streaming - # arguments else: - prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments" + return self._extract_tool_calls_streaming( + delta_text=delta_text, delta_token_ids=delta_token_ids ) - cur_arguments = current_tool_call.get("arguments") - - new_text = delta_text.replace("'", '"') - if '"}' in new_text: - new_text = new_text[: new_text.rindex('"}')] - - if not cur_arguments and not prev_arguments: - delta = None - elif not cur_arguments and prev_arguments: - logger.error( - "INVARIANT - impossible to have arguments reset mid-arguments" - ) - delta = None - elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)[ - :-2 - ] - logger.debug("finding %s in %s", new_text, cur_arguments_json) - - if new_text not in cur_arguments_json: - return None - arguments_delta = cur_arguments_json[ - : cur_arguments_json.rindex(new_text) + len(new_text) - ] - logger.debug( - "First tokens in arguments received: %s", arguments_delta - ) - delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta - ).model_dump(exclude_none=True), - ) - ] - ) - self.streamed_args_for_tool[self.current_tool_id] += arguments_delta - - elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments, ensure_ascii=False) - prev_args_json = json.dumps(prev_arguments, ensure_ascii=False) - logger.debug( - "Searching for diff between \n%s\n%s", - cur_args_json, - prev_args_json, - ) - - argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json - ) - logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff - ).model_dump(exclude_none=True), - ) - ] - ) - self.streamed_args_for_tool[self.current_tool_id] += argument_diff - else: - # try parsing it with regular JSON - if it works we're - # at the end, and we need to send the difference between - # tokens streamed so far and the valid JSON - delta = None - - # check to see if the name is defined and has been sent. if so, - # stream the name - otherwise keep waiting - # finish by setting old and returning None as base case - self.prev_tool_call_arr = tool_call_arr - return delta - except Exception: logger.exception("Error trying to handle streaming tool call.") - logger.debug( - "Skipping chunk as a result of tool streaming extraction error" - ) return None + + def _extract_tool_calls_streaming( + self, + delta_text: str, + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + """ + Extracts tool calls for Mistral models + doing tool calls of the following format: + `[TOOL_CALLS]add{"a": 3.5, "b": 4}` + """ + additional_content: str = "" + if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START: + # this is the first tool call + assert self.bot_token_id in delta_token_ids + if not delta_text.startswith(self.bot_token): + additional_content += delta_text.split(self.bot_token)[0] + delta_text = self.bot_token + "".join( + delta_text.split(self.bot_token)[1:] + ) + + delta_tool_calls = self._generate_delta_tool_call(delta_text) + if not additional_content and len(delta_tool_calls) == 0: + if self.streaming_state in [ + StreamingState.PARSING_ARGUMENTS, + StreamingState.PARSING_ARGUMENTS_COMPLETED, + StreamingState.TOOL_COMPLETE, + StreamingState.ALL_TOOLS_COMPLETE, + ]: + # Return an empty DeltaMessage once the tool calls are all done + # so that finish_reason gets set. + return DeltaMessage() + else: + # return None when the tool is not likely to be finished + # This can occur when the name is being parsed for example + # and we wait for the name to be complete + # before sending the function name + return None + + delta = DeltaMessage() + if additional_content: + delta.content = additional_content + if len(delta_tool_calls) > 0: + delta.tool_calls = delta_tool_calls + + # HACK: serving_chat.py inspects the internal state of tool parsers + # when determining its final streaming delta, automatically + # adding autocompleted JSON. + # These two lines avoid that nonsense while ensuring finish_reason + # is set to tool_calls when at least one tool is called. + if delta_tool_calls and not self.prev_tool_call_arr: + self.prev_tool_call_arr = [{"arguments": {}}] + return delta + + def _generate_delta_tool_call(self, delta_text: str) -> list[DeltaToolCall]: + if delta_text == "" or delta_text is None: + return [] + delta_function_name = None + tool_id = None + if self.streaming_state not in [ + StreamingState.PARSING_NAME, + StreamingState.PARSING_ARGUMENTS, + ] and delta_text.startswith(self.bot_token): + self.current_tool_id += 1 + self.streaming_state = StreamingState.PARSING_NAME + delta_text = delta_text.replace(self.bot_token, "", 1) + if self.streaming_state == StreamingState.PARSING_NAME: + if self.current_tool_name is None: + self.current_tool_name = "" + # The name stops where the arguments start + # And the arguments start with the `{` char + if "{" in delta_text: + tool_id = MistralToolCall.generate_random_id() + delta_function_name = delta_text.split("{")[0] + self.current_tool_name += delta_function_name + delta_text = delta_text[len(delta_function_name) :] + self.streaming_state = StreamingState.PARSING_ARGUMENTS + else: + # we want to send the tool name once it's complete + self.current_tool_name += delta_text + return [] + if self.streaming_state == StreamingState.PARSING_ARGUMENTS: + next_function_text = None + if self.bot_token in delta_text: + # current tool call is over + delta_arguments = "" + delta_arguments += delta_text.split(self.bot_token)[0] + next_function_text = delta_text[len(delta_arguments) :] + self.streaming_state = StreamingState.TOOL_COMPLETE + else: + delta_arguments = delta_text + ret = [] + if self.current_tool_name or delta_arguments: + ret += [ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=self.current_tool_name, arguments=delta_arguments + ).model_dump(exclude_none=True), + ) + ] + self.current_tool_name = None + if next_function_text: + ret += self._generate_delta_tool_call(next_function_text) + return ret + # Should not happen + return [] + + @ijson.coroutine + def update_stream_state_pre_v11_tokenizer(self): + while True: + (prefix, event, value) = yield + + if prefix == "item" and event == "start_map": + self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY + if prefix == "item" and event == "map_key" and value == "name": + self.streaming_state = StreamingState.PARSING_NAME + if prefix == "item.name" and event == "string": + self.current_tool_name = value + self.streaming_state = StreamingState.PARSING_NAME_COMPLETED + if prefix == "item" and event == "map_key" and value == "arguments": + self.streaming_state = StreamingState.WAITING_FOR_ARGUMENTS_START + if prefix == "item.arguments" and event == "start_map": + self.streaming_state = StreamingState.PARSING_ARGUMENTS + if prefix == "item.arguments" and event == "end_map": + self.streaming_state = StreamingState.PARSING_ARGUMENTS_COMPLETED + if prefix == "item" and event == "end_map": + self.streaming_state = StreamingState.TOOL_COMPLETE + if prefix == "" and event == "end_array": + self.streaming_state = StreamingState.ALL_TOOLS_COMPLETE + + def _extract_tool_calls_streaming_pre_v11_tokenizer( + self, + delta_text: str, + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + """ + Extracts tool calls for Mistral models + doing tool calls of the following format: + `[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}` + """ + assert self.parse_coro is not None + content = None + delta_tool_calls: list[DeltaToolCall] = [] + current_tool_call: DeltaToolCall = DeltaToolCall( + index=self.current_tool_id, type="function" + ) + current_tool_call_modified = False + if self.bot_token_id in delta_token_ids: + # this is the first tool call + if not delta_text.startswith(self.bot_token): + content = delta_text.split(self.bot_token)[0] + delta_text = "".join(delta_text.split(self.bot_token)[1:]) + + # Cut smartly the delta text to catch the ijson events + # as ijson does not give us the index in the text at each event. + # We need to cut so that we know + # where in the text the events are emitted from. + while len(delta_text) > 0: + streaming_state_before_parse = self.streaming_state + + if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START: + delta_to_be_parsed, delta_text = self._split_delta( + delta_text=delta_text, + stop_after_opening_curly_braces=1, + ) + elif self.streaming_state == StreamingState.WAITING_FOR_TOOL_KEY: + # Wait until another key is sent + # or the current tool is completed + delta_to_be_parsed, delta_text = self._split_delta( + delta_text=delta_text, + stop_after_colon=1, + stop_after_opening_curly_braces=1, + # if the tool ends, we want to separate + # at the start of the next tool + ) + elif self.streaming_state == StreamingState.PARSING_NAME: + delta_to_be_parsed, delta_text = self._split_delta( + delta_text=delta_text, + stop_after_comma=1, + stop_after_closing_brackets=1, + ) + elif self.streaming_state == StreamingState.WAITING_FOR_ARGUMENTS_START: + delta_to_be_parsed, delta_text = self._split_delta( + delta_text=delta_text, + stop_after_opening_curly_braces=1, + ) + elif self.streaming_state == StreamingState.PARSING_ARGUMENTS: + delta_to_be_parsed, delta_text = self._split_delta( + delta_text=delta_text, + stop_after_closing_curly_braces=1, + # we could be more clever + # by listening to item.arguments.* start_map events + # and know how many curly braces we can allow + ) + elif self.streaming_state in [ + StreamingState.PARSING_ARGUMENTS_COMPLETED, + StreamingState.PARSING_NAME_COMPLETED, + ]: + delta_to_be_parsed, delta_text = self._split_delta( + delta_text=delta_text, + stop_after_closing_curly_braces=1, + stop_after_closing_brackets=1, + ) + elif self.streaming_state == StreamingState.TOOL_COMPLETE: + delta_to_be_parsed, delta_text = self._split_delta( + delta_text=delta_text, + stop_after_opening_curly_braces=1, + stop_after_closing_brackets=1, + ) + elif self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE: + content = delta_text + delta_text = "" + else: + delta_to_be_parsed = delta_text + delta_text = "" + + if self.streaming_state != StreamingState.ALL_TOOLS_COMPLETE: + self.parse_coro.send(delta_to_be_parsed.encode("utf-8")) + + # Given the parsed text and the possible streaming state change, + # let's add to the tool delta + if ( + (streaming_state_before_parse != self.streaming_state) + and streaming_state_before_parse + in [StreamingState.WAITING_FOR_TOOL_START, StreamingState.TOOL_COMPLETE] + and self.streaming_state + not in [ + StreamingState.ALL_TOOLS_COMPLETE, + StreamingState.TOOL_COMPLETE, + StreamingState.WAITING_FOR_TOOL_START, + ] + ): + # starting a new tool call + if current_tool_call_modified: + if self.current_tool_mistral_id is not None: + current_tool_call.id = self.current_tool_mistral_id + self.current_tool_mistral_id = None + delta_tool_calls.append(current_tool_call) + current_tool_call_modified = False + self.current_tool_id += 1 + self.current_tool_mistral_id = MistralToolCall.generate_random_id() + current_tool_call = DeltaToolCall( + index=self.current_tool_id, + type="function", + ) + if current_tool_call.function is None: + current_tool_call.function = DeltaFunctionCall() + + if self.current_tool_name is not None: + # we have the complete tool name + current_tool_call_modified = True + current_tool_call.function.name = self.current_tool_name + self.current_tool_name = None + if self.streaming_state == StreamingState.PARSING_NAME_COMPLETED: + self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY + if self.streaming_state in [ + StreamingState.PARSING_ARGUMENTS, + StreamingState.PARSING_ARGUMENTS_COMPLETED, + ]: + if self.streaming_state == StreamingState.PARSING_ARGUMENTS_COMPLETED: + self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY + # the delta_to_be_parsed is part of arguments. + current_tool_call_modified = True + if current_tool_call.function.arguments is None: + current_tool_call.function.arguments = delta_to_be_parsed + else: + current_tool_call.function.arguments += delta_to_be_parsed + if streaming_state_before_parse != StreamingState.PARSING_ARGUMENTS: + # It's the first chunk of arg. let's lstrip it + current_tool_call.function.arguments = ( + current_tool_call.function.arguments.lstrip() + ) + + if current_tool_call_modified: + if self.current_tool_mistral_id is not None: + current_tool_call.id = self.current_tool_mistral_id + self.current_tool_mistral_id = None + delta_tool_calls.append(current_tool_call) + + # HACK: serving_chat.py inspects the internal state of tool parsers + # when determining it's final streaming delta, automatically + # adding autocompleted JSON. + # These two lines avoid that nonsense while ensuring finish_reason + # is set to tool_calls when at least one tool is called. + if delta_tool_calls and not self.prev_tool_call_arr: + self.prev_tool_call_arr = [{"arguments": {}}] + + if content or len(delta_tool_calls) > 0: + delta_message = DeltaMessage() + if content: + delta_message.content = content + if len(delta_tool_calls) > 0: + delta_message.tool_calls = delta_tool_calls + return delta_message + else: + if self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE: + return DeltaMessage() + else: + return None + + def _split_delta( + self, + delta_text: str, + stop_after_quotes: int = -1, + stop_after_opening_curly_braces: int = -1, + stop_after_closing_curly_braces: int = -1, + stop_after_closing_brackets: int = -1, + stop_after_colon: int = -1, + stop_after_comma=-1, + ) -> tuple[str, str]: + delta_to_be_parsed = "" + for i, c in enumerate(delta_text): + if c in ['"', "'"]: + delta_to_be_parsed += c + stop_after_quotes -= 1 + if stop_after_quotes == 0: + return (delta_to_be_parsed, delta_text[i + 1 :]) + elif c == "{": + delta_to_be_parsed += c + stop_after_opening_curly_braces -= 1 + if stop_after_opening_curly_braces == 0: + return (delta_to_be_parsed, delta_text[i + 1 :]) + elif c == "}": + delta_to_be_parsed += c + stop_after_closing_curly_braces -= 1 + if stop_after_closing_curly_braces == 0: + return (delta_to_be_parsed, delta_text[i + 1 :]) + elif c == "]": + delta_to_be_parsed += c + stop_after_closing_brackets -= 1 + if stop_after_closing_brackets == 0: + return (delta_to_be_parsed, delta_text[i + 1 :]) + elif c == ":": + delta_to_be_parsed += c + stop_after_colon -= 1 + if stop_after_colon == 0: + return (delta_to_be_parsed, delta_text[i + 1 :]) + elif c == ",": + delta_to_be_parsed += c + stop_after_comma -= 1 + if stop_after_comma == 0: + return (delta_to_be_parsed, delta_text[i + 1 :]) + else: + delta_to_be_parsed += c + + return (delta_to_be_parsed, "") diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index 8819c85af9a26..072ddd4c90b16 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -89,12 +89,10 @@ def parse_score_data( data_1: str | ScoreContentPartParam, data_2: str | ScoreContentPartParam, model_config: ModelConfig, - tokenizer: TokenizerLike, ) -> tuple[str, str, MultiModalDataDict | None]: - mm_tracker = MultiModalItemTracker(model_config, tokenizer) + mm_tracker = MultiModalItemTracker(model_config) content_1 = _parse_score_content(data_1, mm_tracker) - content_2 = _parse_score_content(data_2, mm_tracker) def ensure_str(content: _ContentPart | None) -> str: @@ -188,7 +186,6 @@ def get_score_prompt( data_1, data_2, model_config, - tokenizer, ) from vllm.model_executor.model_loader import get_model_cls diff --git a/vllm/entrypoints/serve/instrumentator/metrics.py b/vllm/entrypoints/serve/instrumentator/metrics.py index efe0c63a90714..5231451383a2b 100644 --- a/vllm/entrypoints/serve/instrumentator/metrics.py +++ b/vllm/entrypoints/serve/instrumentator/metrics.py @@ -2,9 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import re - import prometheus_client +import regex as re from fastapi import FastAPI, Response from prometheus_client import make_asgi_app from prometheus_fastapi_instrumentator import Instrumentator diff --git a/vllm/envs.py b/vllm/envs.py index 8b954fa14f28c..37711dece9abc 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -37,7 +37,7 @@ if TYPE_CHECKING: VLLM_DISABLE_FLASHINFER_PREFILL: bool = False VLLM_DO_NOT_TRACK: bool = False VLLM_USAGE_SOURCE: str = "" - VLLM_CONFIGURE_LOGGING: int = 1 + VLLM_CONFIGURE_LOGGING: bool = True VLLM_LOGGING_LEVEL: str = "INFO" VLLM_LOGGING_PREFIX: str = "" VLLM_LOGGING_STREAM: str = "ext://sys.stdout" @@ -78,8 +78,8 @@ if TYPE_CHECKING: MAX_JOBS: str | None = None NVCC_THREADS: str | None = None VLLM_USE_PRECOMPILED: bool = False + VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX: bool = False VLLM_DOCKER_BUILD_CONTEXT: bool = False - VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False CMAKE_BUILD_TYPE: Literal["Debug", "Release", "RelWithDebInfo"] | None = None VERBOSE: bool = False @@ -175,6 +175,7 @@ if TYPE_CHECKING: VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5600 + VLLM_MOONCAKE_BOOTSTRAP_PORT: int = 8998 VLLM_ALL2ALL_BACKEND: Literal[ "naive", "pplx", @@ -197,6 +198,7 @@ if TYPE_CHECKING: VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: int | None = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 480 + VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT: int = 480 VLLM_USE_CUDNN_PREFILL: bool = False VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL: bool = False VLLM_ENABLE_CUDAGRAPH_GC: bool = False @@ -462,17 +464,16 @@ environment_variables: dict[str, Callable[[], Any]] = { .lower() in ("1", "true") or bool(os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), + # If set, skip adding +precompiled suffix to version string + "VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX": lambda: bool( + int(os.environ.get("VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX", "0")) + ), # Used to mark that setup.py is running in a Docker build context, # in order to force the use of precompiled binaries. "VLLM_DOCKER_BUILD_CONTEXT": lambda: os.environ.get("VLLM_DOCKER_BUILD_CONTEXT", "") .strip() .lower() in ("1", "true"), - # Whether to force using nightly wheel in python build. - # This is used for testing the nightly wheel in python build. - "VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL": lambda: bool( - int(os.getenv("VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL", "0")) - ), # CMake build type # If not set, defaults to "Debug" or "RelWithDebInfo" # Available options: "Debug", "Release", "RelWithDebInfo" @@ -618,7 +619,9 @@ environment_variables: dict[str, Callable[[], Any]] = { # If set to 0, vllm will not configure logging # If set to 1, vllm will configure logging using the default configuration # or the configuration file specified by VLLM_LOGGING_CONFIG_PATH - "VLLM_CONFIGURE_LOGGING": lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")), + "VLLM_CONFIGURE_LOGGING": lambda: bool( + int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) + ), "VLLM_LOGGING_CONFIG_PATH": lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), # this is used for configuring the default logging level "VLLM_LOGGING_LEVEL": lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO").upper(), @@ -1259,6 +1262,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_NIXL_SIDE_CHANNEL_PORT": lambda: int( os.getenv("VLLM_NIXL_SIDE_CHANNEL_PORT", "5600") ), + # Port used for Mooncake handshake between remote agents. + "VLLM_MOONCAKE_BOOTSTRAP_PORT": lambda: int( + os.getenv("VLLM_MOONCAKE_BOOTSTRAP_PORT", "8998") + ), # all2all backend for vllm's expert parallel communication # Available options: # - "naive": naive all2all implementation using broadcasts @@ -1368,6 +1375,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": lambda: int( os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "480") ), + # Timeout (in seconds) for MooncakeConnector in PD disaggregated setup. + "VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT": lambda: int( + os.getenv("VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT", "480") + ), # Controls whether or not to use cudnn prefill "VLLM_USE_CUDNN_PREFILL": lambda: bool( int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0")) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index d0081ff72eba6..b961a35d3b84f 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -661,9 +661,9 @@ class LoRAModelManager: def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): assert isinstance(module, BaseLayerWithLoRA), ( - f"Module {module_name} must be a BaseLayerWithLoRA instance," + f"Module {module_name} must be a BaseLayerWithLoRA instance, " + f"got {type(module)}" ) - f" got {type(module)}" self.modules[module_name] = module def create_dummy_lora( diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 413ee8ecbbf96..34383cdf1767c 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -96,10 +96,14 @@ def _fused_moe_lora_kernel( slice_id = tl.program_id(axis=1) lora_idx = tl.program_id(axis=2) lora_id = tl.load(lora_ids + lora_idx) - moe_enabled = tl.load(adapter_enabled + lora_id) - if lora_id == -1 or moe_enabled == 0: + + if lora_id == -1: # Early exit for the no-lora case. return + moe_enabled = tl.load(adapter_enabled + lora_id) + if moe_enabled == 0: + # Early exit for the no moe lora case. + return max_loras = tl.num_programs(axis=2) grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 669abcb3d6ff1..9103e84aa7057 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -60,9 +60,6 @@ if HAS_TRITON: from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts, ) - from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts, - ) from vllm.model_executor.layers.fused_moe.cutlass_moe import ( CutlassBatchedExpertsFp8, CutlassExpertsFp8, @@ -98,7 +95,6 @@ if HAS_TRITON: "DeepGemmExperts", "BatchedDeepGemmExperts", "TritonOrDeepGemmExperts", - "BatchedTritonOrDeepGemmExperts", ] else: # Some model classes directly use the custom ops. Add placeholders diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py deleted file mode 100644 index e69e9fd307aeb..0000000000000 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ /dev/null @@ -1,180 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( - BatchedDeepGemmExperts, -) -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts -from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout - - -class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): - def __init__( - self, - max_num_tokens: int, - num_dispatchers: int, - quant_config: FusedMoEQuantConfig, - allow_deep_gemm: bool = False, - ): - super().__init__(quant_config) - - self.batched_triton_experts = BatchedTritonExperts( - max_num_tokens=max_num_tokens, - num_dispatchers=num_dispatchers, - quant_config=self.quant_config, - ) - - self.allow_deep_gemm = ( - allow_deep_gemm - and self.quant_config.use_fp8_w8a8 - and self.block_shape == get_mk_alignment_for_contiguous_layout() - ) - - self.batched_deep_gemm_experts = ( - BatchedDeepGemmExperts( - max_num_tokens=max_num_tokens, - num_dispatchers=num_dispatchers, - quant_config=self.quant_config, - ) - if self.allow_deep_gemm - else None - ) - - assert ( - self.batched_deep_gemm_experts is not None - or self.batched_triton_experts is not None - ) - - @property - def activation_formats( - self, - ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: - if self.batched_triton_experts is not None: - assert ( - self.batched_deep_gemm_experts is None - or self.batched_deep_gemm_experts.activation_formats - == self.batched_triton_experts.activation_formats - ) - return self.batched_triton_experts.activation_formats - else: - assert self.batched_deep_gemm_experts is not None - return self.batched_deep_gemm_experts.activation_formats - - def supports_chunking(self) -> bool: - bdge = self.batched_deep_gemm_experts - bte = self.batched_triton_experts - return (bdge is None or bdge.supports_chunking()) and ( - bte is None or bte.supports_chunking() - ) - - def supports_expert_map(self) -> bool: - bdge = self.batched_deep_gemm_experts - bte = self.batched_triton_experts - return (bdge is None or bdge.supports_expert_map()) and ( - bte is None or bte.supports_expert_map() - ) - - def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: - bdge = self.batched_deep_gemm_experts - bte = self.batched_triton_experts - bdge_war = bdge.finalize_weight_and_reduce_impl() if bdge else None - bte_war = bte.finalize_weight_and_reduce_impl() if bte else None - is_bdge_war = bdge_war is not None - is_bte_war = bte_war is not None - - if is_bdge_war and is_bte_war: - assert bdge_war == bte_war, ( - "Both implementations should agree on WeightAndReduce impls. " - f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}" - ) - - if bdge_war is not None: - return bdge_war - - assert bte_war is not None - return bte_war - - def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: - return act_dtype - - def workspace_shapes( - self, - M: int, - N: int, - K: int, - topk: int, - global_num_experts: int, - local_num_experts: int, - expert_tokens_metadata: mk.ExpertTokensMetadata | None, - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: - # Note: the deep gemm workspaces are strictly larger than the triton - # workspaces so we can be pessimistic here and allocate for DeepGemm - # even if we fall back to triton later, e.g. if expert maps are set. - if self.allow_deep_gemm: - assert self.batched_deep_gemm_experts is not None - return self.batched_deep_gemm_experts.workspace_shapes( - M, - N, - K, - topk, - global_num_experts, - local_num_experts, - expert_tokens_metadata, - ) - else: - assert self.batched_triton_experts is not None - return self.batched_triton_experts.workspace_shapes( - M, - N, - K, - topk, - global_num_experts, - local_num_experts, - expert_tokens_metadata, - ) - - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: torch.Tensor | None, - a1q_scale: torch.Tensor | None, - a2_scale: torch.Tensor | None, - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_tokens_meta: mk.ExpertTokensMetadata | None, - apply_router_weight_on_input: bool, - ): - experts = ( - self.batched_deep_gemm_experts - if self.allow_deep_gemm - else self.batched_triton_experts - ) - assert experts is not None - experts.apply( - output, - hidden_states, - w1, - w2, - topk_weights, - topk_ids, - activation, - global_num_experts, - expert_map, - a1q_scale, - a2_scale, - workspace13, - workspace2, - expert_tokens_meta, - apply_router_weight_on_input, - ) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 86cdd25f2c873..9f47e692d5ae2 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -2,9 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch -from tqdm import tqdm -import vllm.envs as env import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( @@ -25,12 +23,12 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, + silu_mul_per_token_group_quant_fp8_colmajor, ) from vllm.utils.deep_gemm import ( get_mk_alignment_for_contiguous_layout, m_grouped_fp8_gemm_nt_contiguous, ) -from vllm.utils.func_utils import run_once from vllm.utils.import_utils import has_deep_gemm logger = init_logger(__name__) @@ -108,70 +106,6 @@ def _valid_deep_gemm( return True -@run_once -def warmup_deepgemm_gg_contiguous_kernels( - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - num_topk: int, -): - """ - DeepGemm JITs the grouped-gemm kernels. The JIT'ing happens based on the - input tensor shapes. In this function, we construct all possible input - tensor shapes so all the kernels are JIT'ed and cached. - Note that this warmup is expected to happen during the model profile - call and not during actual model inference. - """ - - assert w1.size(0) == w2.size(0), "w1 and w2 must have the same number of experts" - - block_m = get_mk_alignment_for_contiguous_layout()[0] - num_experts = w1.size(0) - device = w1.device - - # This is the maximum GroupedGemm M size that we expect to run - # the grouped_gemm with. - MAX_M = compute_aligned_M( - env.VLLM_FUSED_MOE_CHUNK_SIZE, - num_topk, - num_experts, - block_m, - expert_tokens_meta=None, - ) - # Distribute expert-ids evenly. - MAX_BLOCKS = MAX_M // block_m - expert_ids_block = torch.randint( - low=0, high=num_experts, size=(MAX_BLOCKS,), device=device, dtype=torch.int32 - ) - expert_ids = torch.repeat_interleave(expert_ids_block, block_m, dim=0) - - def _warmup(w: torch.Tensor, w_scale: torch.Tensor): - _, n, k = w.size() - a1q = torch.empty((MAX_M, k), device=device).to(torch.float8_e4m3fn) - a1q_scales = torch.empty( - (MAX_M, k // block_m), device=device, dtype=torch.float32 - ) - out = torch.empty((MAX_M, n), device=device, dtype=torch.bfloat16) - - pbar = tqdm( - total=MAX_BLOCKS, desc=f"DeepGemmExperts GEMM warmup (MAX_M={MAX_M})" - ) - num_tokens = MAX_M - while num_tokens > 0: - m_grouped_fp8_gemm_nt_contiguous( - (a1q[:num_tokens], a1q_scales[:num_tokens]), - (w, w_scale), - out[:num_tokens], - expert_ids[:num_tokens], - ) - pbar.update(1) - num_tokens = num_tokens - block_m - - _warmup(w1, w1_scale) - _warmup(w2, w2_scale) - - class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def __init__(self, quant_config: FusedMoEQuantConfig): super().__init__(quant_config) @@ -215,11 +149,32 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ) assert M_sum % block_m == 0 - workspace1 = (M_sum, N) - workspace2 = (M_sum, max(N // 2, K)) + workspace1 = (M_sum, max(N // 2, K)) + workspace2 = (M_sum, max(N, K)) output = (M, K) return (workspace1, workspace2, output) + def _act_mul_quant( + self, input: torch.Tensor, output: torch.Tensor, activation: str + ) -> tuple[torch.Tensor, torch.Tensor]: + if activation == "silu": + return silu_mul_per_token_group_quant_fp8_colmajor( + input=input, output=output + ) + else: + # This is a fallback path. If we find ourselves using any activation other + # than silu, we should add that activation to + # silu_mul_per_token_group_quant_fp8_colmajor kernel as it is much faster. + M_sum, N = input.size() + act_out = torch.empty( + (M_sum, N // 2), dtype=input.dtype, device=input.device + ) + self.activation(activation, act_out, input) + assert self.block_shape is not None + return per_token_group_quant_fp8( + act_out, self.block_shape[1], column_major_scales=True, out_q=output + ) + def apply( self, output: torch.Tensor, @@ -261,14 +216,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta=expert_tokens_meta, ) - a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M_sum, K)) - mm1_out = _resize_cache(workspace13, (M_sum, N)) - act_out = _resize_cache(workspace2, (M_sum, N // 2)) - quant_out = _resize_cache( - workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2) + a1q_perm = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K) ) - mm2_out = _resize_cache(workspace2, (M_sum, K)) - a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute( aq=a1q, aq_scale=a1q_scale, @@ -280,17 +230,19 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ) assert a1q.size(0) == M_sum + mm1_out = _resize_cache(workspace2, (M_sum, N)) m_grouped_fp8_gemm_nt_contiguous( (a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids ) - self.activation(activation, act_out, mm1_out.view(-1, N)) - - a2q_scale: torch.Tensor | None = None - a2q, a2q_scale = per_token_group_quant_fp8( - act_out, self.block_shape[1], column_major_scales=True, out_q=quant_out + quant_out = _resize_cache( + workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2) + ) + a2q, a2q_scale = self._act_mul_quant( + input=mm1_out.view(-1, N), output=quant_out, activation=activation ) + mm2_out = _resize_cache(workspace2, (M_sum, K)) m_grouped_fp8_gemm_nt_contiguous( (a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids ) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 7dd02e32ff211..d1942689d7f5c 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -64,42 +64,6 @@ class PoolingParamsUpdate: params.requires_token_ids = self.requires_token_ids -def get_prompt_lens( - hidden_states: torch.Tensor | list[torch.Tensor], - pooling_metadata: PoolingMetadata, -) -> torch.Tensor: - return pooling_metadata.prompt_lens - - -def get_prompt_token_ids(pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: - assert pooling_metadata.prompt_token_ids is not None, ( - "Please set `requires_token_ids=True` in `get_pooling_updates`" - ) - - return [ - pooling_metadata.prompt_token_ids[i, :num] - for i, num in enumerate(pooling_metadata.prompt_lens) - ] - - -def get_pooling_params(pooling_metadata: PoolingMetadata) -> list[PoolingParams]: - pooling_params = pooling_metadata.pooling_params - return pooling_params - - -def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: - pooling_params = get_pooling_params(pooling_metadata) - - tasks: list[PoolingTask] = [ - task - for pooling_param in pooling_params - if (task := pooling_param.task) is not None - ] - assert len(pooling_params) == len(tasks) - - return tasks - - def get_classification_activation_function(config: PretrainedConfig): # Implement alignment with transformers ForSequenceClassificationLoss # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92 @@ -163,14 +127,14 @@ class PoolingMethod(nn.Module, ABC): self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, - ) -> list[torch.Tensor] | torch.Tensor: + ) -> PoolerOutput: raise NotImplementedError def forward( self, hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, - ) -> list[torch.Tensor] | torch.Tensor: + ) -> PoolerOutput: pooling_cursor = pooling_metadata.pooling_cursor return self.forward_all(hidden_states, pooling_cursor) @@ -183,7 +147,7 @@ class CLSPool(PoolingMethod): self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, - ) -> list[torch.Tensor] | torch.Tensor: + ) -> PoolerOutput: assert not pooling_cursor.is_partial_prefill(), ( "partial prefill not supported with CLS pooling" ) @@ -199,27 +163,65 @@ class LastPool(PoolingMethod): self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, - ) -> list[torch.Tensor] | torch.Tensor: + ) -> PoolerOutput: return hidden_states[pooling_cursor.last_token_indices_gpu] class AllPool(PoolingMethod): + def __init__(self): + super().__init__() + + vllm_config = get_current_vllm_config() + self.enable_chunked_prefill = ( + vllm_config.scheduler_config.enable_chunked_prefill + ) + def get_supported_tasks(self) -> Set[PoolingTask]: return {"token_embed", "token_classify"} def forward_all( - self, - hidden_states: torch.Tensor, - pooling_cursor: PoolingCursor, - ) -> list[torch.Tensor] | torch.Tensor: - assert not pooling_cursor.is_partial_prefill(), ( - "partial prefill not supported with ALL pooling" + self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor + ) -> PoolerOutput: + raise NotImplementedError( + "forward_all is not implemented for AllPool. Use forward instead." ) + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooling_cursor = pooling_metadata.pooling_cursor + is_finished = pooling_cursor.is_finished() hidden_states_lst = list( hidden_states.split(pooling_cursor.num_scheduled_tokens_cpu.tolist()) ) - return [hidden_states_lst[i] for i in pooling_cursor.index] + hidden_states_lst = [hidden_states_lst[i] for i in pooling_cursor.index] + + if not self.enable_chunked_prefill: + return hidden_states_lst + + pooling_states = pooling_metadata.pooling_states + + # If chunked_prefill is enabled + # 1. first store the chunked hidden_states in pooling_states.hidden_states_cache + for p, hs_chunk in zip(pooling_states, hidden_states_lst): + p.hidden_states_cache.append(hs_chunk) + + # 2. Once prefill is finished, send hidden_states_cache to PoolerHead + output_list: PoolerOutput = [] + for p, finished in zip(pooling_states, is_finished): + if finished: + hidden_states_cache = p.hidden_states_cache + if len(hidden_states_cache) == 1: + output_list.append(hidden_states_cache[0]) + else: + output_list.append(torch.concat(hidden_states_cache, dim=0)) + p.clean() + else: + output_list.append(None) + + return output_list class MeanPool(PoolingMethod): @@ -230,7 +232,7 @@ class MeanPool(PoolingMethod): self, hidden_states: torch.Tensor, pooling_cursor: PoolingCursor, - ) -> list[torch.Tensor] | torch.Tensor: + ) -> PoolerOutput: assert not pooling_cursor.is_partial_prefill(), ( "partial prefill not supported with MEAN pooling" ) @@ -435,7 +437,7 @@ class PoolerHead(nn.Module): self, pooled_data: list[torch.Tensor] | torch.Tensor, pooling_metadata: PoolingMetadata, - ): + ) -> PoolerOutput: return self.activation(pooled_data) @@ -454,7 +456,7 @@ class EmbeddingPoolerHead(PoolerHead): self, pooled_data: list[torch.Tensor] | torch.Tensor, pooling_metadata: PoolingMetadata, - ): + ) -> PoolerOutput: if isinstance(pooled_data, list): pooled_data = torch.stack(pooled_data) # pooled_data shape: [batchsize, hidden_dimension] @@ -466,7 +468,7 @@ class EmbeddingPoolerHead(PoolerHead): pooled_data = self.projector(pooled_data) # pooled_data shape: [batchsize, embedding_dimension] - pooling_params = get_pooling_params(pooling_metadata) + pooling_params = pooling_metadata.pooling_params # for matryoshka representation dimensions_list = [pooling_param.dimensions for pooling_param in pooling_params] @@ -606,7 +608,7 @@ class ClassifierPooler(Pooler): if self.logit_bias is not None: pooled_data -= self.logit_bias - pooling_params = get_pooling_params(pooling_metadata) + pooling_params = pooling_metadata.pooling_params flags = [p.use_activation for p in pooling_params] if len(set(flags)) == 1: @@ -622,8 +624,12 @@ class ClassifierPooler(Pooler): class TokenEmbeddingPoolerHead(EmbeddingPoolerHead): def forward( - self, pooled_data: torch.Tensor, pooling_param: PoolingParams - ) -> torch.Tensor: + self, pooled_data: torch.Tensor | None, pooling_param: PoolingParams + ) -> PoolerOutput: + # for unfinished chunked prefill + if pooled_data is None: + return None + pooled_data = pooled_data.to(self.head_dtype) # pooled_data shape: [n_tokens, hidden_dimension] @@ -666,9 +672,13 @@ class TokenClassifierPoolerHead(nn.Module): def forward( self, - hidden_states: torch.Tensor, + hidden_states: torch.Tensor | None, pooling_param: PoolingParams, - ) -> torch.Tensor: + ) -> PoolerOutput: + # for unfinished chunked prefill + if hidden_states is None: + return None + hidden_states = hidden_states.to(self.head_dtype) # hidden_states shape: [n_token, hidden_size] @@ -704,7 +714,7 @@ class AllPooler(Pooler): pooling_metadata: PoolingMetadata, ) -> PoolerOutput: pooled_data = self.pooling(hidden_states, pooling_metadata) - pooling_params = get_pooling_params(pooling_metadata) + pooling_params = pooling_metadata.pooling_params assert len(pooled_data) == len(pooling_params) pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] @@ -722,17 +732,20 @@ class StepPooler(Pooler): self, hidden_states: torch.Tensor | list[torch.Tensor], pooling_metadata: PoolingMetadata, - ) -> torch.Tensor | list[torch.Tensor]: + ) -> PoolerOutput: pooled_data_lst = self.pooling(hidden_states, pooling_metadata) - prompt_token_ids = get_prompt_token_ids(pooling_metadata) - - pooled_data = list[torch.Tensor]() - - pooling_params = get_pooling_params(pooling_metadata) + prompt_token_ids = pooling_metadata.get_prompt_token_ids() + pooling_params = pooling_metadata.pooling_params + pooled_data: PoolerOutput = [] for data, token_id, pooling_param in zip( pooled_data_lst, prompt_token_ids, pooling_params ): + # for unfinished chunked prefill + if data is None: + pooled_data.append(data) + continue + step_tag_id = pooling_param.step_tag_id returned_token_ids = pooling_param.returned_token_ids @@ -757,7 +770,7 @@ class StepPooler(Pooler): pooling_metadata: PoolingMetadata, ) -> PoolerOutput: pooled_data = self.extract_states(hidden_states, pooling_metadata) - pooling_params = get_pooling_params(pooling_metadata) + pooling_params = pooling_metadata.pooling_params assert len(pooled_data) == len(pooling_params) pooled_data = [self.head(d, p) for d, p in zip(pooled_data, pooling_params)] @@ -794,7 +807,7 @@ class DispatchPooler(Pooler): outputs = list[torch.Tensor]() offset = 0 - for task, group in groupby(get_tasks(pooling_metadata)): + for task, group in groupby(pooling_metadata.tasks): if not (pooler := poolers_by_task.get(task)): raise ValueError( f"Unsupported task: {task} " diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index c7368bf427fe1..d7fb6d2ca367d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -90,8 +90,10 @@ from vllm.platforms import CpuArchEnum, current_platform from vllm.scalar_type import scalar_types from vllm.utils.deep_gemm import ( get_col_major_tma_aligned_tensor, + get_mk_alignment_for_contiguous_layout, is_deep_gemm_e8m0_used, ) +from vllm.utils.import_utils import has_deep_gemm logger = init_logger(__name__) @@ -1088,9 +1090,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): return experts - # triton path - from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 - BatchedTritonOrDeepGemmExperts, + from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts, + ) + from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts, ) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts, @@ -1098,6 +1102,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): assert not self.rocm_aiter_moe_enabled and not self.use_marlin + use_deep_gemm = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM + if ( prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts @@ -1105,22 +1111,47 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() assert max_num_tokens_per_rank is not None - logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) - return BatchedTritonOrDeepGemmExperts( - max_num_tokens=max_num_tokens_per_rank, - num_dispatchers=prepare_finalize.num_dispatchers(), - quant_config=self.moe_quant_config, - allow_deep_gemm=( - envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM - ), + if use_deep_gemm and not has_deep_gemm(): + raise RuntimeError( + "DeepGEMM requested for MoE layer but not installed." + ) + + compatible_with_deep_gemm = ( + self.moe_quant_config.use_fp8_w8a8 + and self.moe_quant_config.block_shape + == get_mk_alignment_for_contiguous_layout() ) + + # If this MoE layer is compatible with DeepGEMM, the proper env + # vars are set and DeepGEMM is not installed, throw an error. + if use_deep_gemm and compatible_with_deep_gemm and not has_deep_gemm(): + raise RuntimeError( + f"MoE layer incompatible with DeepGEMM, expected " + f"fp8==True, got {self.moe_quant_config.use_fp8_w8a8}" + f"or block_shape {self.moe_quant_config.block_shape}" + f"=={get_mk_alignment_for_contiguous_layout()}." + ) + + if use_deep_gemm and compatible_with_deep_gemm and has_deep_gemm(): + logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__) + return BatchedDeepGemmExperts( + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=self.moe_quant_config, + ) + else: + logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__) + return BatchedTritonExperts( + max_num_tokens=max_num_tokens_per_rank, + num_dispatchers=prepare_finalize.num_dispatchers(), + quant_config=self.moe_quant_config, + ) + else: logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__) return TritonOrDeepGemmExperts( self.moe_quant_config, - allow_deep_gemm=( - envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM - ), + allow_deep_gemm=use_deep_gemm, ) def get_fused_moe_quant_config( diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index ae63b4a767268..6e73833d1ae1c 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -492,6 +492,139 @@ def _per_token_group_quant_fp8( tl.store(y_s_ptr, y_s) +@triton.jit +def _silu_mul_per_token_group_quant_fp8_colmajor( + y_ptr, # [M, N] + y_q_ptr, # [M, N // 2] + y_s_ptr, # [M, (N // 2) // GROUP_SIZE] + M, # num tokens + N, # intermediate size + # Stride + y_s_col_stride: tl.int64, + # Information for float8 + eps, + fp8_min, + fp8_max, + use_ue8m0: tl.constexpr, + # Meta-parameters + GROUP_SIZE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # TODO(varun) : Add expert_ids so we may early-exit no-op thread blocks. + """ + Each thread block (BLOCK_N) computes [BLOCK_M, GROUP_SIZE] act-mul outputs. Then + the thread block quantizes the [BLOCK_M, GROUP_SIZE] block of values and fills + the outputs tensors at the right positions. + """ + + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + N_2 = N // 2 + + m_offset = pid_m * BLOCK_M + n_offset = pid_n * BLOCK_N + if m_offset >= M: + return + + offs_n = tl.arange(0, BLOCK_N).to(tl.int64) + offs_m = tl.arange(0, BLOCK_M).to(tl.int64) + + base_y_ptr = y_ptr + m_offset * N + n_offset + + act_in_ptrs = base_y_ptr + offs_m[:, None] * N + offs_n[None, :] + + act_in = tl.load(act_in_ptrs) + mul_in = tl.load(act_in_ptrs + N_2) + + # silu & mul + act_in = act_in.to(tl.float32) + one_f32 = tl.cast(1, tl.float32) + silu_out = (act_in / (one_f32 + tl.exp(-act_in))).to(y_ptr.dtype.element_ty) + y = (silu_out * mul_in).to(tl.float32) + + # quant + _absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps) + scale_raw = _absmax / fp8_max + y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw + y_s = tl.reshape(y_s, (BLOCK_M, 1)) + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + # store y_q + base_y_q_ptr = y_q_ptr + m_offset * N_2 + n_offset + y_q_ptrs = base_y_q_ptr + offs_m[:, None] * N_2 + offs_n[None, :] + tl.store(y_q_ptrs, y_q) + + # store y_s + group_id = n_offset // GROUP_SIZE + base_y_s_ptr = y_s_ptr + group_id * y_s_col_stride + m_offset + y_s_ptrs = base_y_s_ptr + offs_m + y_s = tl.reshape(y_s, (BLOCK_M,)) + tl.store(y_s_ptrs, y_s) + + +def silu_mul_per_token_group_quant_fp8_colmajor( + input: torch.Tensor, # [M, N] + output: torch.Tensor | None = None, # [M, N // 2] + use_ue8m0: bool | None = None, + eps: float = 1e-10, +): + """ + silu+mul + block-fp8 quant with group size 128. + """ + GROUP_SIZE = 128 + assert input.ndim == 2 + if output is not None: + assert output.ndim == 2 + assert input.size(0) % GROUP_SIZE == 0 + assert input.size(1) % (GROUP_SIZE * 2) == 0 + + if use_ue8m0 is None: + use_ue8m0 = is_deep_gemm_e8m0_used() + + M, N = input.size() + N_2 = N // 2 + + if output is None: + output = torch.empty((M, N_2), dtype=torch.float8_e4m3fn, device=input.device) + + output_scales = torch.empty( + ((N_2 // GROUP_SIZE), M), dtype=torch.float32, device=input.device + ).transpose(0, 1) + + BLOCK_M = 8 + BLOCK_N = GROUP_SIZE + assert M % BLOCK_M == 0 + assert N_2 % BLOCK_N == 0 + + finfo = torch.finfo(torch.float8_e4m3fn) + fp8_min = finfo.min + fp8_max = finfo.max + + # Force even division so we can avoid edgecases within the kernel. + assert M % BLOCK_M == 0 + assert N_2 % BLOCK_N == 0 + grid = (M // BLOCK_M, N_2 // BLOCK_N) + + _silu_mul_per_token_group_quant_fp8_colmajor[grid]( + input, + output, + output_scales, + M, + N, + output_scales.stride(-1), + eps, + fp8_min, + fp8_max, + use_ue8m0, + GROUP_SIZE, + BLOCK_M, + BLOCK_N, + ) + + return output, output_scales + + @triton.jit def _per_token_group_quant_fp8_colmajor( # Pointers to inputs and output diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 550e8b014d5e7..2aba626a7c737 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -14,8 +14,6 @@ from vllm.model_executor.layers.pooler import ( PoolerHead, PoolerNormalize, PoolingParamsUpdate, - get_prompt_lens, - get_prompt_token_ids, ) from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.tasks import PoolingTask @@ -153,11 +151,11 @@ class GritLMMeanPool(nn.Module): hidden_states: torch.Tensor | list[torch.Tensor], pooling_metadata: PoolingMetadata, ) -> list[torch.Tensor] | torch.Tensor: - prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) + prompt_lens = pooling_metadata.prompt_lens instr_lens = torch.tensor( [ self._get_instruction_len(token_ids.cpu().numpy()) - for token_ids in get_prompt_token_ids(pooling_metadata) + for token_ids in pooling_metadata.get_prompt_token_ids() ], device="cpu", ) diff --git a/vllm/model_executor/models/hunyuan_vision.py b/vllm/model_executor/models/hunyuan_vision.py index 6537b6df876a9..5aef09ca9c256 100644 --- a/vllm/model_executor/models/hunyuan_vision.py +++ b/vllm/model_executor/models/hunyuan_vision.py @@ -62,6 +62,7 @@ from vllm.multimodal.inputs import ( from vllm.multimodal.parse import ( DictEmbeddingItems, ImageSize, + ModalityDataItems, MultiModalDataItems, MultiModalDataParser, ) @@ -570,7 +571,7 @@ class HunYuanVLMultiModalDataParser(MultiModalDataParser): def _parse_image_data( self, data: dict[str, torch.Tensor] | ModalityData[ImageItem], - ): + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 8817601558148..09acf8372e168 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -1000,7 +1000,7 @@ class KeyeMultiModalDataParser(MultiModalDataParser): def _parse_image_data( self, data: dict[str, torch.Tensor] | ModalityData[ImageItem], - ) -> ModalityDataItems[Any, Any]: + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, @@ -1017,7 +1017,7 @@ class KeyeMultiModalDataParser(MultiModalDataParser): def _parse_video_data( self, data: dict[str, torch.Tensor] | ModalityData[VideoItem], - ) -> ModalityDataItems[Any, Any]: + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py index 124e9c2afa217..2b04e3bd4b75b 100644 --- a/vllm/model_executor/models/keye_vl1_5.py +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -333,7 +333,7 @@ class KeyeVL1_5MultiModalDataParser(MultiModalDataParser): def _parse_image_data( self, data: dict[str, torch.Tensor] | ModalityData[ImageItem], - ) -> ModalityDataItems[Any, Any]: + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, @@ -350,7 +350,7 @@ class KeyeVL1_5MultiModalDataParser(MultiModalDataParser): def _parse_video_data( self, data: dict[str, torch.Tensor] | ModalityData[VideoItem], - ) -> ModalityDataItems[Any, Any]: + ) -> ModalityDataItems[Any, Any] | None: if isinstance(data, dict): return DictEmbeddingItems( data, diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py deleted file mode 100644 index 0f1230a55bae6..0000000000000 --- a/vllm/model_executor/models/phi4_multimodal.py +++ /dev/null @@ -1,1447 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math -from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, TypeAlias - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers import ( - BatchFeature, - Phi4MultimodalAudioConfig, - Phi4MultimodalConfig, - Phi4MultimodalFeatureExtractor, - Phi4MultimodalImageProcessorFast, -) -from transformers import Phi4MultimodalProcessor as Phi4MMProcessor -from transformers.models.phi4_multimodal.modeling_phi4_multimodal import ( - Phi4MultimodalAudioConvModule, - Phi4MultimodalAudioNemoConvSubsampling, - Phi4MultimodalAudioRelativeAttentionBias, - adaptive_enc_mask, - unfold_tensor, -) - -from vllm.config import VllmConfig -from vllm.config.multimodal import BaseDummyOptions -from vllm.distributed import ( - divide, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import ( - MultiModalDataDict, - MultiModalFieldConfig, - MultiModalKwargsItems, - NestedTensors, -) -from vllm.multimodal.parse import ( - AudioProcessorItems, - ImageEmbeddingItems, - ImageProcessorItems, - ImageSize, - MultiModalDataItems, - MultiModalDataParser, -) -from vllm.multimodal.processing import ( - BaseMultiModalProcessor, - BaseProcessingInfo, - PromptReplacement, - PromptUpdate, -) -from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.sequence import IntermediateTensors -from vllm.utils.tensor_schema import TensorSchema, TensorShape - -from .idefics2_vision_model import Idefics2VisionTransformer -from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal -from .utils import ( - AutoWeightsLoader, - WeightsMapper, - init_vllm_registered_model, - maybe_prefix, -) - -_AUDIO_MAX_SOUNDFILE_SIZE = 241_000 - - -def _get_padding_size( - orig_width: int, orig_height: int, target_height: int, target_width: int -): - ratio_width = target_width / orig_width - ratio_height = target_height / orig_height - - if ratio_width < ratio_height: - padding_width = 0 - padding_height = target_height - int(orig_height * ratio_width) - else: - padding_width = target_width - int(orig_width * ratio_height) - padding_height = 0 - return padding_height, padding_width - - -class Phi4MMProjector(nn.Module): - def __init__(self, input_size: int, hidden_size: int): - super().__init__() - self.up = ColumnParallelLinear(input_size, hidden_size) - self.down = RowParallelLinear(hidden_size, hidden_size) - self.act = get_act_fn("gelu") - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x, _ = self.up(x) - x = self.act(x) - x, _ = self.down(x) - return x - - -class Phi4MMImageEmbedding(nn.Module): - """Image embedding.""" - - def __init__(self, config: Phi4MultimodalConfig): - super().__init__() - self.config = config - self.layer_idx = config.vision_config.feature_layer - self.crop_size = config.vision_config.crop_size - self.image_dim_out = config.vision_config.hidden_size - - n_patches = config.vision_config.image_size // config.vision_config.patch_size - if n_patches % 2 != 0: - self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) - n_patches += 1 - self.num_img_tokens = (n_patches // 2) ** 2 - - num_hidden_layers = ( - config.vision_config.num_hidden_layers + self.layer_idx + 1 - if self.layer_idx < 0 - else self.layer_idx + 1 - ) - self.img_processor = Idefics2VisionTransformer( - config.vision_config, - require_post_norm=False, - num_hidden_layers_override=num_hidden_layers, - ) - self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) - self.img_projection = Phi4MMProjector(self.image_dim_out, config.hidden_size) - self.global_img_feature_extensor = nn.Parameter( - torch.zeros([1, 1, self.image_dim_out]) - ) - self.sub_img_feature_extensor = nn.Parameter( - torch.zeros([1, 1, 1, self.image_dim_out]) - ) - - def get_img_features( - self, - img_embeds: torch.FloatTensor, - attention_mask: torch.Tensor | None = None, - ) -> torch.FloatTensor: - img_feature = self.img_processor( - img_embeds, patch_attention_mask=attention_mask - ) - - patch_feature = img_feature - # reshape to 2D tensor - width = int(math.sqrt(patch_feature.size(1))) - patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) - # convert to NCHW - patch_feature = patch_feature.permute(0, 3, 1, 2) - if getattr(self, "img_processor_padding", None) is not None: - patch_feature = self.img_processor_padding(patch_feature) - patch_feature = self.image_token_compression(patch_feature) - # convert to NHWC - patch_feature = patch_feature.permute(0, 2, 3, 1) - patch_feature = patch_feature.view( - -1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1) - ) - return patch_feature - - def forward( - self, - image_pixel_values: torch.FloatTensor, - image_sizes: torch.Tensor | None = None, - image_attention_mask: torch.Tensor | None = None, - ) -> torch.FloatTensor: - image_pixel_values = image_pixel_values.to( - self.img_processor.embeddings.patch_embedding.weight.dtype - ) - - target_device = self.img_projection.up.bias.device - target_dtype = self.img_projection.up.bias.dtype - - batch_size = image_pixel_values.shape[0] - - img_features = self.get_img_features( - image_pixel_values.flatten(0, 1), - attention_mask=image_attention_mask.flatten(0, 1).to( - dtype=bool, device=target_device - ), - ) - base_feat_size = int(np.sqrt(img_features.shape[1])) - img_features = img_features.view( - batch_size, -1, base_feat_size**2, self.image_dim_out - ) - image_sizes = image_sizes.view(-1, 2) - - output_imgs = [] - for idx in range(batch_size): - height, width = image_sizes[idx] - height_ratio = height // self.crop_size - width_ratio = width // self.crop_size - area_ratio = height_ratio * width_ratio - - global_img = img_features[idx, :1] - global_img = global_img.reshape( - 1, base_feat_size, base_feat_size, self.image_dim_out - ).contiguous() - temporary_extensor = self.sub_img_feature_extensor.repeat( - 1, base_feat_size, 1, 1 - ) - global_img = torch.cat([global_img, temporary_extensor], dim=2).reshape( - 1, -1, self.image_dim_out - ) - - sub_img = img_features[idx, 1:] - sub_img = sub_img[:area_ratio] - sub_img = ( - sub_img.reshape( - height_ratio, - width_ratio, - base_feat_size, - base_feat_size, - self.image_dim_out, - ) - .transpose(1, 2) - .reshape( - 1, - height_ratio * base_feat_size, - width_ratio * base_feat_size, - self.image_dim_out, - ) - .contiguous() - ) - - if image_attention_mask is not None: - reshaped_image_attention_mask = ( - image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2] - .reshape(height_ratio, width_ratio, base_feat_size, base_feat_size) - .transpose(1, 2) - .reshape( - 1, height_ratio * base_feat_size, width_ratio * base_feat_size - ) - ) - useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item()) - useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item()) - sub_img = sub_img[:, :useful_height, :useful_width] - temporary_extensor = self.sub_img_feature_extensor.repeat( - 1, useful_height, 1, 1 - ) - else: - temporary_extensor = self.sub_img_feature_extensor.repeat( - 1, height_ratio * base_feat_size, 1, 1 - ) - - sub_img = torch.cat([sub_img, temporary_extensor], dim=2).reshape( - 1, -1, self.image_dim_out - ) - - # Merge global and sub - output_imgs.append( - torch.cat( - [sub_img, self.global_img_feature_extensor, global_img], dim=1 - ) - ) - - img_set_tensor = [] - for output_img in output_imgs: - output_img = output_img.to(device=target_device, dtype=target_dtype) - img_feature_proj = self.img_projection(output_img) - img_set_tensor.append(img_feature_proj.flatten(0, 1)) - - return img_set_tensor - - -class Phi4MultimodalAudioMLP(nn.Module): - def __init__( - self, - config: Phi4MultimodalAudioConfig, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - super().__init__() - self.layer_norm = nn.LayerNorm(config.hidden_size) - self.act_fn = MulAndSilu() - self.gate_up_proj = MergedColumnParallelLinear( - config.hidden_size, - [config.intermediate_size] * 2, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", - ) - self.down_proj = RowParallelLinear( - config.intermediate_size, - config.hidden_size, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.down_proj", - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.layer_norm(hidden_states) - hidden_states, _ = self.gate_up_proj(hidden_states) - hidden_states = self.act_fn(hidden_states) - hidden_states, _ = self.down_proj(hidden_states) - return hidden_states - - -class Phi4MultimodalAudioAttention(nn.Module): - def __init__( - self, - config: Phi4MultimodalAudioConfig, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.total_num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.total_num_heads - if self.head_dim * self.total_num_heads != self.embed_dim: - raise ValueError( - "embed_dim must be divisible by num_heads " - f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - self.scale = self.head_dim**-0.5 - - self.qkv_proj = QKVParallelLinear( - hidden_size=self.embed_dim, - head_size=self.head_dim, - total_num_heads=self.total_num_heads, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - - self.o_proj = RowParallelLinear( - input_size=self.embed_dim, - output_size=self.embed_dim, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.num_heads = divide(self.total_num_heads, self.tp_size) - - def split_attn_mask(self, attention_mask: torch.Tensor) -> torch.Tensor: - start_idx = self.num_heads * self.tp_rank - end_idx = self.num_heads * (self.tp_rank + 1) - return attention_mask[:, start_idx:end_idx] - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - ) -> torch.Tensor: - qkv_states, _ = self.qkv_proj(hidden_states) - query, key, value = qkv_states.chunk(3, dim=-1) - - bsz, seq_len, _ = query.size() - query = query.view(bsz, seq_len, self.num_heads, self.head_dim) - key = key.view(bsz, seq_len, self.num_heads, self.head_dim) - value = value.view(bsz, seq_len, self.num_heads, self.head_dim) - query, key, value = (x.transpose(1, 2) for x in (query, key, value)) - - attention_mask = self.split_attn_mask(attention_mask) - out = F.scaled_dot_product_attention( - query, - key, - value, - scale=self.scale, - attn_mask=attention_mask, - ) - out = out.transpose(1, 2).reshape(bsz, seq_len, -1) - - attn_output, _ = self.o_proj(out) - - return attn_output - - -class Phi4MultimodalAudioConformerEncoderLayer(nn.Module): - def __init__(self, config: Phi4MultimodalAudioConfig): - super().__init__() - - self.feed_forward_in = Phi4MultimodalAudioMLP(config) - self.self_attn = Phi4MultimodalAudioAttention(config) - self.conv = Phi4MultimodalAudioConvModule(config) - self.feed_forward_out = Phi4MultimodalAudioMLP(config) - self.layer_norm_att = nn.LayerNorm(config.hidden_size) - self.layer_norm = nn.LayerNorm(config.hidden_size) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - ) -> torch.Tensor: - residual = hidden_states + 0.5 * self.feed_forward_in(hidden_states) - hidden_states = self.layer_norm_att(residual) - - hidden_states = residual + self.self_attn(hidden_states, attention_mask) - hidden_states = hidden_states + self.conv(hidden_states) - hidden_states = hidden_states + 0.5 * self.feed_forward_out(hidden_states) - - out = self.layer_norm(hidden_states) - - return out - - -class Phi4MMAudioMeanVarianceNormLayer(nn.Module): - """Mean/variance normalization layer. - - Will subtract mean and multiply input by inverted standard deviation. - Typically used as a very first layer in a model. - - Args: - config: [Phi4MultimodalAudioConfig](https://huggingface.co/docs/transformers/model_doc/phi4_multimodal#transformers.Phi4MultimodalAudioConfig) - object containing model parameters. - """ - - def __init__(self, config: Phi4MultimodalAudioConfig): - super().__init__() - self.global_mean = nn.Parameter(torch.zeros(config.input_size)) - self.global_invstd = nn.Parameter(torch.ones(config.input_size)) - - def forward(self, input_: torch.Tensor) -> torch.Tensor: - """MeanVarianceNormLayer Forward - - Args: - input_: torch.Tensor - input tensor. - """ - return (input_ - self.global_mean) * self.global_invstd - - -class Phi4MultimodalAudioModel(nn.Module): - def __init__(self, config: Phi4MultimodalAudioConfig): - super().__init__() - self.config = config - - self.encoder_embedding = Phi4MMAudioMeanVarianceNormLayer(config) - self.embed = Phi4MultimodalAudioNemoConvSubsampling(config) - self.relative_attention_bias_layer = Phi4MultimodalAudioRelativeAttentionBias( - config - ) - self.encoders = nn.ModuleList( - [ - Phi4MultimodalAudioConformerEncoderLayer(config) - for _ in range(config.num_blocks) - ] - ) - - def _streaming_mask( - self, - seq_len: int, - batch_size: int, - chunk_size: int, - left_chunk: int, - ): - # Create mask matrix for streaming - # S stores start index. if chunksize is 18, s is [0,18,36,....] - chunk_start_idx = np.arange(0, seq_len, chunk_size) - - enc_streaming_mask = ( - adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk) - .unsqueeze(0) - .expand([batch_size, -1, -1]) - ) - return enc_streaming_mask - - def forward_embeddings( - self, - hidden_states: torch.Tensor, - masks: torch.Tensor, - ): - """Forwarding the inputs through the top embedding layers""" - seq_len = math.ceil(hidden_states.shape[1] / self.config.time_reduction) - if seq_len <= 0: - raise ValueError( - f"Sequence length after time reduction is invalid: {seq_len}." - "Your input feature is too short." - ) - - batch_size = hidden_states.shape[0] - - enc_streaming_mask = self._streaming_mask( - seq_len, batch_size, self.config.chunk_size, self.config.left_chunk - ) - enc_streaming_mask = enc_streaming_mask.to(hidden_states.device) - - hidden_states, masks = self.embed(hidden_states, masks) - - streaming_mask = enc_streaming_mask - if streaming_mask is not None and masks is not None: - hs_mask = masks & streaming_mask - elif masks is not None: - hs_mask = masks - else: - hs_mask = streaming_mask - - return hidden_states, hs_mask, masks - - def calculate_hs_mask( - self, hidden_states: torch.Tensor, device: torch.device, mask: torch.Tensor - ): - max_audio_length = hidden_states.shape[1] - batch_size = hidden_states.shape[0] - enc_streaming_mask = self._streaming_mask( - max_audio_length, batch_size, self.config.chunk_size, self.config.left_chunk - ) - enc_streaming_mask = enc_streaming_mask.to(device) - if mask is None: - return enc_streaming_mask - - feature_lens = mask.sum(1) - padding_length = feature_lens - pad_mask = torch.arange(0, max_audio_length, device=device).expand( - padding_length.size(0), -1 - ) < padding_length.unsqueeze(1) - pad_mask = pad_mask.unsqueeze(1) - pad_mask = pad_mask & enc_streaming_mask - return pad_mask - - def forward(self, hidden_states: torch.Tensor, mask: torch.Tensor | None = None): - hidden_states = self.encoder_embedding(hidden_states) - hidden_states, hs_mask, mask = self.forward_embeddings(hidden_states, mask) - - unfolded = False - bs, seq_len, _ = hidden_states.shape - max_seq_len = 500 # maximum position for absolute positional encoding - if seq_len > max_seq_len: - # audio sequence is longer than max_seq_len, - # unfold it into chunks of max_seq_len - unfolded = True - # the unfold op will drop residual frames, - # pad it to the multiple of max_seq_len - if seq_len % max_seq_len > 0: - chunk_pad_size = max_seq_len - (seq_len % max_seq_len) - else: - chunk_pad_size = 0 - if chunk_pad_size > 0: - hidden_states_pad = F.pad( - hidden_states, (0, 0, 0, chunk_pad_size), "constant", 0 - ) - hidden_states = hidden_states_pad.to(hidden_states.device) - - hidden_states = unfold_tensor(hidden_states, max_seq_len) - masks_unfold = None - if mask is not None: - # revise hs_mask here because the previous calculated hs_mask - # did not consider extra pad - subsampled_pad_mask = mask.squeeze(1) # [bz, subsampled_unmask_seq_len] - extra_padded_subsamlped_pad_mask = F.pad( - subsampled_pad_mask, (0, chunk_pad_size), "constant", False - ) # extra padding to the pad mask - extra_padded_subsamlped_pad_mask = ( - extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() - ) - masks_unfold = unfold_tensor( - extra_padded_subsamlped_pad_mask, max_seq_len - ) # unfold the pad mask like we did to the input tensor - masks_unfold = masks_unfold.squeeze( - -1 - ).bool() # unfold op does not support bool tensor - hs_mask = self.calculate_hs_mask( - hidden_states, hidden_states.device, masks_unfold - ) # calculate hs_mask based on the unfolded pad mask - - relative_attention_bias = self.relative_attention_bias_layer(hidden_states) - attention_mask = hs_mask.unsqueeze(1) + relative_attention_bias - - for layer in self.encoders: - hidden_states = layer(hidden_states, attention_mask) - - if unfolded: - embed_dim = hidden_states.shape[-1] - hidden_states = hidden_states.reshape(bs, -1, embed_dim) - # if we ever padded before unfolding, we need to remove the padding - if chunk_pad_size > 0: - hidden_states = hidden_states[:, :-chunk_pad_size, :] - - return hidden_states - - -class Phi4MMAudioEmbedding(nn.Module): - def __init__(self, config: Phi4MultimodalConfig): - super().__init__() - self.config = config - self.layer_idx = config.audio_config.feature_layer - - self.encoder = Phi4MultimodalAudioModel(config.audio_config) - - audio_config = config.audio_config - proj_input_size = audio_config.hidden_size * audio_config.downsample_rate - self.vision_speech_projection = Phi4MMProjector( - proj_input_size, config.hidden_size - ) - self.speech_projection = Phi4MMProjector(proj_input_size, config.hidden_size) - - def get_projection( - self, - audio_projection_mode: Literal["speech", "vision"], - ) -> Phi4MMProjector: - if audio_projection_mode == "speech": - return self.speech_projection - elif audio_projection_mode == "vision": - return self.vision_speech_projection - - def forward( - self, - audio_input_features: torch.FloatTensor, - audio_embed_sizes=None, - audio_attention_mask=None, - audio_projection_mode="speech", - ) -> torch.FloatTensor: - audio_projection = self.get_projection(audio_projection_mode) - - target_device = audio_projection.up.bias.device - target_dtype = audio_projection.up.bias.dtype - - audio_input_features = audio_input_features.to( - device=target_device, dtype=target_dtype - ) - - audio_encoder_hidden_states = self.encoder( - audio_input_features, audio_attention_mask - ) - audio_embeds = audio_projection(audio_encoder_hidden_states) - - return audio_embeds.flatten(0, 1) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class Phi4MMImagePixelInputs(TensorSchema): - """ - Dimensions: - - bn: Batch size * number of images - - p: Number of patches (1 + num_patches) - - c: Number of channels (3) - - h: Height of each image patch - - w: Width of each image patch - - nc: Number of crops - - H_mask: Height of attention mask - - W_mask: Width of attention mask - """ - - type: Literal["pixel_values"] - - pixel_values: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape( - "bn", "p", 3, "h", "w", dynamic_dims={"p"} - ), # may be different per batch and image - ] - - image_sizes: Annotated[ - torch.Tensor, - TensorShape("bn", 2), # (height, width) - ] - - num_img_tokens: Annotated[ - list[int], - TensorShape("bn"), - ] - - image_attention_mask: Annotated[ - torch.Tensor, - TensorShape("bn", "nc", 32, 32), # H_mask, W_mask - ] - - -class Phi4MMImageEmbeddingInputs(TensorSchema): - """ - Dimensions: - - bn: Batch size * number of images - - f: Image feature size - - h: Hidden size (must match language model backbone) - """ - - type: Literal["image_embeds"] - - data: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape("bn", "f", "h"), - ] - - -class Phi4MMAudioFeatureInputs(TensorSchema): - """ - Dimensions: - - bn: Batch size * number of audios - - f: Number of Mel filterbank bins (80) - - t: Time frames (M) - """ - - type: Literal["audio_features"] - - audio_features: Annotated[ - torch.Tensor | list[torch.Tensor], - TensorShape("bn", "t", 80, dynamic_dims={"t"}), - ] - - -class Phi4MMAudioEmbeddingInputs(TensorSchema): - """ - Dimensions: - - b: Batch size - - n: Number of audios - - f: Audio feature size - - h: Hidden size (must match language model backbone) - """ - - type: Literal["audio_embeds"] - - data: Annotated[ - NestedTensors, - TensorShape("b", "n", "f", "h"), - ] - - -Phi4MMImageInput: TypeAlias = Phi4MMImagePixelInputs | Phi4MMImageEmbeddingInputs -Phi4MMAudioInputs: TypeAlias = Phi4MMAudioFeatureInputs | Phi4MMAudioEmbeddingInputs - - -def cat_with_pad(tensors, dim, padding_value=0): - """ - cat along dim, while pad to max for all other dims - """ - ndim = tensors[0].dim() - assert all(t.dim() == ndim for t in tensors[1:]), ( - "All tensors must have the same number of dimensions" - ) - - out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)] - out_size[dim] = sum(t.shape[dim] for t in tensors) - output = tensors[0].new_full(out_size, padding_value) - - index = 0 - for t in tensors: - # Create a slice list where every dimension except dim is full slice - slices = [slice(0, t.shape[d]) for d in range(ndim)] - # Update only the concat dimension slice - slices[dim] = slice(index, index + t.shape[dim]) - - output[slices] = t - index += t.shape[dim] - - return output - - -class Phi4MMProcessingInfo(BaseProcessingInfo): - def get_hf_config(self) -> Phi4MultimodalConfig: - return self.ctx.get_hf_config(Phi4MultimodalConfig) - - def get_hf_processor(self, **kwargs: object) -> Phi4MMProcessor: - return self.ctx.get_hf_processor(Phi4MMProcessor, **kwargs) - - def get_feature_extractor(self, **kwargs: object) -> Phi4MultimodalFeatureExtractor: - return self.get_hf_processor(**kwargs).audio_processor - - def get_image_processor( - self, - processor: Phi4MMProcessor | None = None, - ) -> Phi4MultimodalImageProcessorFast: - if processor is None: - processor = self.get_hf_processor() - return processor.image_processor - - def get_dynamic_hd( - self, - processor: Phi4MMProcessor | None = None, - ) -> int: - return self.get_image_processor(processor).dynamic_hd - - def get_supported_mm_limits(self) -> Mapping[str, int | None]: - return {"audio": None, "image": None} - - def _find_target_aspect_ratio( - self, - orig_width: int, - orig_height: int, - image_size: int, - max_num: int, - min_num: int, - ): - w_crop_num = math.ceil(orig_width / float(image_size)) - h_crop_num = math.ceil(orig_height / float(image_size)) - if w_crop_num * h_crop_num > max_num: - aspect_ratio = orig_width / orig_height - - # calculate the existing image aspect ratio - target_ratios = set( - (i, j) - for i in range(1, max_num + 1) - for j in range(1, max_num + 1) - if i * j <= max_num and i * j >= min_num - ) - target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) - - # find the closest aspect ratio to the target - image_processor = self.get_image_processor() - target_aspect_ratio = image_processor.find_closest_aspect_ratio( - aspect_ratio, - target_ratios, - orig_width, - orig_height, - image_size, - ) - - # calculate the target width and height - target_width = image_size * target_aspect_ratio[0] - target_height = image_size * target_aspect_ratio[1] - else: - target_width = image_size * w_crop_num - target_height = image_size * h_crop_num - target_aspect_ratio = (w_crop_num, h_crop_num) - return target_aspect_ratio, target_height, target_width - - def _compute_num_image_tokens( - self, - orig_width: int, - orig_height: int, - dynamic_hd_size: int, - vit_image_size: int, - vit_patch_size: int, - token_compression_factor: int = 2, - ): - """ - compute the number of tokens an image is expected to take up considering - the image encoder architecture and exclude output features containing - only padding pixels - - for siglip, vit_image_size=448, vit_patch_size=14, so output will be - 32x32 feature map - NOTE right now, Phi4MM uses hard-coded token_compression_factor=2 - """ - assert vit_image_size % vit_patch_size == 0, ( - "vit_image_size must be divisible by vit_patch_size" - ) - assert vit_image_size // vit_patch_size % token_compression_factor == 0, ( - "vit_image_size // vit_patch_size must be divisible by " - "token_compression_factor" - ) - - target_aspect_ratio, target_height, target_width = ( - self._find_target_aspect_ratio( - orig_width, orig_height, vit_image_size, dynamic_hd_size, min_num=1 - ) - ) - assert target_aspect_ratio[0] * vit_image_size == target_width, ( - f"{target_aspect_ratio[0]} * {vit_image_size} != {target_width}" - ) - assert target_aspect_ratio[1] * vit_image_size == target_height, ( - f"{target_aspect_ratio[1]} * {vit_image_size} != {target_height}" - ) - assert ( - target_height % vit_image_size == 0 and target_width % vit_image_size == 0 - ) - - padding_height, padding_width = _get_padding_size( - orig_width, orig_height, target_height, target_width - ) - assert padding_width == 0 or padding_height == 0, ( - "padding_width or padding_height must be 0" - ) - - target_feat_width = target_width // vit_patch_size - target_feat_height = target_height // vit_patch_size - if padding_width >= vit_patch_size: - assert padding_height == 0, "padding_height not 0" - non_pad_feat_width = target_feat_width - math.floor( - padding_width / vit_patch_size - ) - non_pad_feat_height = target_feat_height - elif padding_height >= vit_patch_size: - assert padding_width == 0, "padding_width not 0" - non_pad_feat_height = target_feat_height - math.floor( - padding_height / vit_patch_size - ) - non_pad_feat_width = target_feat_width - else: - # small padding shorter than a vit patch - non_pad_feat_width = target_feat_width - non_pad_feat_height = target_feat_height - - feat_width = non_pad_feat_width // token_compression_factor - feat_height = non_pad_feat_height // token_compression_factor - # NOTE it's possible that the non-padding feature is not divisible - if non_pad_feat_width % token_compression_factor != 0: - feat_width += 1 - if non_pad_feat_height % token_compression_factor != 0: - feat_height += 1 - num_hd_patch_tokens = feat_width * feat_height - num_hd_newline_tokens = feat_height - vit_feature_size = vit_image_size // vit_patch_size - num_global_image_tokens = (vit_feature_size // token_compression_factor) ** 2 - num_sep_tokens = 1 - num_global_image_newline_tokens = vit_feature_size // token_compression_factor - - return ( - num_global_image_tokens - + num_sep_tokens - + num_hd_patch_tokens - + num_hd_newline_tokens - + num_global_image_newline_tokens - ) - - def get_num_image_tokens( - self, - *, - image_width: int, - image_height: int, - processor: Phi4MMProcessor | None = None, - ) -> int: - hf_config = self.get_hf_config() - vision_config = hf_config.vision_config - vit_image_size = vision_config.image_size - vit_patch_size = vision_config.patch_size - - dynamic_hd_size = self.get_dynamic_hd(processor=processor) - - # we use default `token_compression_factor=2`, - # since it's not in HF vision config. - image_num_tokens = self._compute_num_image_tokens( - image_width, - image_height, - dynamic_hd_size=dynamic_hd_size, - vit_image_size=vit_image_size, - vit_patch_size=vit_patch_size, - ) - - return image_num_tokens - - def get_image_size_with_most_features( - self, - processor: Phi4MMProcessor | None = None, - ) -> ImageSize: - vit_image_size = self.get_hf_config().vision_config.image_size - - max_side = vit_image_size * self.get_dynamic_hd(processor=processor) - return ImageSize(height=max_side, width=vit_image_size) - - def get_audio_num_frames(self, audio_len: int, sr: float) -> int: - """ - Compute the output size of the `extract_features` method. - - Args: - audio_len (int): Length of the input waveform in samples. - sr (float): Sampling rate of the waveform, either 16000 or 8000. - - Returns: - tuple (int, int): Output size as (T, D), where: - T: Number of time frames. - D: Number of Mel filterbank bins (80). - """ - - # Resample to 16000 or 8000 if needed - if sr > 16000: - audio_len //= sr // 16000 - elif 8000 <= sr < 16000: - # We'll resample to 16K from 8K - audio_len *= 2 - elif sr < 8000: - raise RuntimeError(f"Unsupported sample rate {sr}") - - # Spectrogram parameters for 16 kHz - win_length = 400 # Frame length in samples - hop_length = 160 # Frame shift in samples - - # Calculate number of frames (T) - num_frames = (audio_len - win_length) // hop_length + 1 - if num_frames < 1: - raise ValueError("Waveform too short for given parameters.") - - # Return time frames (T) - return num_frames - - def _compute_audio_embed_size(self, audio_frames: int) -> int: - """ - Compute the size of audio embeddings from the number of audio frames. - """ - # `_compute_audio_embed_size` in audio_processor use torch for - # computation, therefore we re-implement it to use pythonic - # numeric computation to avoid extra tensor conversion. - audio_processor = self.get_feature_extractor() - audio_compression_rate = audio_processor.audio_compression_rate - audio_downsample_rate = audio_processor.audio_downsample_rate - - integer = audio_frames // audio_compression_rate - remainder = audio_frames % audio_compression_rate - result = integer + int(remainder > 0) - - integer = result // audio_downsample_rate - remainder = result % audio_downsample_rate - result = integer + int(remainder > 0) # qformer compression - - return result - - -class Phi4MMDummyInputsBuilder(BaseDummyInputsBuilder[Phi4MMProcessingInfo]): - def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: - num_audios = mm_counts.get("audio", 0) - num_images = mm_counts.get("image", 0) - - tokenizer = self.info.get_tokenizer() - image_tokens: str = tokenizer.image_token * num_images - audio_tokens: str = tokenizer.audio_token * num_audios - - return image_tokens + audio_tokens - - def get_dummy_mm_data( - self, - seq_len: int, - mm_counts: Mapping[str, int], - mm_options: Mapping[str, BaseDummyOptions] | None = None, - ) -> MultiModalDataDict: - num_audios = mm_counts.get("audio", 0) - num_images = mm_counts.get("image", 0) - - target_width, target_height = self.info.get_image_size_with_most_features() - - image_overrides = mm_options.get("image") if mm_options else None - audio_overrides = mm_options.get("audio") if mm_options else None - - mm_data = { - "image": self._get_dummy_images( - width=target_width, - height=target_height, - num_images=num_images, - overrides=image_overrides, - ), - "audio": self._get_dummy_audios( - length=_AUDIO_MAX_SOUNDFILE_SIZE, - num_audios=num_audios, - overrides=audio_overrides, - ), - } - - return mm_data - - -class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): - def _get_data_parser(self) -> MultiModalDataParser: - feature_extractor = self.info.get_feature_extractor() - return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) - - def _call_hf_processor( - self, - prompt: str, - mm_data: Mapping[str, object], - mm_kwargs: Mapping[str, object], - tok_kwargs: Mapping[str, object], - ) -> BatchFeature: - if not mm_data: - prompt_ids = self.info.get_tokenizer().encode(prompt) - prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) - return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") - - audio_data = mm_data.pop("audios", []) - if audio_data: - mm_data["audio"] = audio_data - - processed_outputs = super()._call_hf_processor( - prompt, mm_data, mm_kwargs, tok_kwargs - ) - - if "image_pixel_values" in processed_outputs: - num_img_tokens = [ - self.info.get_num_image_tokens( - image_width=img_size[0], image_height=img_size[1] - ) - for img_size in processed_outputs["image_sizes"] - ] - processed_outputs["num_img_tokens"] = num_img_tokens - - if audio_data: - audio_features = processed_outputs["audio_input_features"] - sr = self.info.get_feature_extractor(**mm_kwargs).sampling_rate - feature_sizes = [ - self.info.get_audio_num_frames(len(audio), sr) for audio in audio_data - ] - processed_outputs["audio_input_features"] = [ - audio_features[idx, :size] for idx, size in enumerate(feature_sizes) - ] - - return processed_outputs - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - image_pixel_values=MultiModalFieldConfig.batched("image"), - image_attention_mask=MultiModalFieldConfig.batched("image"), - image_sizes=MultiModalFieldConfig.batched("image"), - num_img_tokens=MultiModalFieldConfig.batched("image"), - audio_input_features=MultiModalFieldConfig.batched("audio"), - ) - - def _get_prompt_updates( - self, - mm_items: MultiModalDataItems, - hf_processor_mm_kwargs: Mapping[str, Any], - out_mm_kwargs: MultiModalKwargsItems, - ) -> Sequence[PromptUpdate]: - tokenizer = self.info.get_tokenizer() - image_token_id: int = tokenizer.vocab[tokenizer.image_token] - audio_token_id: int = tokenizer.vocab[tokenizer.audio_token] - - hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - audio_processor = self.info.get_feature_extractor(**hf_processor_mm_kwargs) - - def get_image_replacement_phi4mm(item_idx: int): - images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems) - ) - - if isinstance(images, ImageEmbeddingItems): - num_image_tokens = images.get_feature_size(item_idx) - else: - image_size = images.get_image_size(item_idx) - num_image_tokens = self.info.get_num_image_tokens( - image_width=image_size.width, - image_height=image_size.height, - processor=hf_processor, - ) - - return [image_token_id] * num_image_tokens - - def get_audio_replacement_phi4mm(item_idx: int): - audios = mm_items.get_items("audio", AudioProcessorItems) - # TODO(Isotr0py): support embedding inputs - audio_len = audios.get_audio_length(item_idx) - audio_frames = self.info.get_audio_num_frames( - audio_len, audio_processor.sampling_rate - ) - audio_embed_size = self.info._compute_audio_embed_size(audio_frames) - - return [audio_token_id] * audio_embed_size - - return [ - PromptReplacement( - modality="audio", - target=[audio_token_id], - replacement=get_audio_replacement_phi4mm, - ), - PromptReplacement( - modality="image", - target=[image_token_id], - replacement=get_image_replacement_phi4mm, - ), - ] - - -@MULTIMODAL_REGISTRY.register_processor( - Phi4MMMultiModalProcessor, - info=Phi4MMProcessingInfo, - dummy_inputs=Phi4MMDummyInputsBuilder, -) -class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): - """ - Implements the Phi-4-multimodal-instruct model in vLLM. - """ - - merge_by_field_config = True - - packed_modules_mapping = { - "qkv_proj": [ - "qkv_proj", - ], - "gate_up_proj": [ - "gate_up_proj", - ], - } - - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_prefix={ - # Multimodal embedding - "model.embed_tokens_extend.": "", - # LLM backbone - "model.": "language_model.model.", - }, - orig_to_new_substr={ - # projection - ".img_projection_": ".img_projection.", - ".up_proj_for_speech.": ".speech_projection.up.", - ".up_proj_for_vision_speech.": ".vision_speech_projection.up.", - ".down_proj_for_speech.": ".speech_projection.down.", - ".down_proj_for_vision_speech.": ".vision_speech_projection.down.", - }, - ) - - @classmethod - def get_placeholder_str(cls, modality: str, i: int) -> str | None: - if modality.startswith("image"): - return "<|image|>" - if modality.startswith("audio"): - return "<|audio|>" - - raise ValueError("Only image or audio modality is supported") - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - multimodal_config = vllm_config.model_config.multimodal_config - self.config = config - self.multimodal_config = multimodal_config - - # TODO: Optionally initializes these for supporting input embeddings. - self.image_embed = Phi4MMImageEmbedding( - config, - # prefix=maybe_prefix(prefix, "image_embed"), - ) - self.audio_embed = Phi4MMAudioEmbedding( - config, - # prefix=maybe_prefix(prefix, "audio_embed"), - ) - - self.language_model = init_vllm_registered_model( - vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model"), - architectures=["Phi3ForCausalLM"], - ) - - self.make_empty_intermediate_tensors = ( - self.language_model.make_empty_intermediate_tensors - ) - - def _parse_and_validate_audio_input( - self, **kwargs: object - ) -> Phi4MMAudioInputs | None: - """ - Parse and validate the audio input to the model. This handles both - audio features and audio embeddings, but only the former is used for - now. - - Args: - kwargs (object): Keyword arguments. - - Returns: - Optional[Phi4MMAudioInputs]: Parsed and validated audio inputs. - """ - audio_features = kwargs.pop("audio_input_features", None) - audio_embeds = kwargs.pop("audio_embeds", None) - - if audio_features is None and audio_embeds is None: - return None - - if audio_features is not None: - return Phi4MMAudioFeatureInputs( - type="audio_features", - audio_features=audio_features, - ) - - if audio_embeds is not None: - return Phi4MMAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) - - raise AssertionError("This line should be unreachable.") - - def _process_audio_input( - self, audio_input: Phi4MMAudioInputs, audio_projection_mode: str - ) -> NestedTensors: - """ - Create the audio embeddings from the audio input, where the audio input - is pairs of audio features and audio embed lengths. The audio input is - created by `input_mapper_for_phi4mm_audio`. - - Args: - audio_input (Phi4MMAudioInputs): Audio input. - - Returns: - NestedTensors: Audio embeddings - """ - if audio_input["type"] == "audio_embeds": - return audio_input["data"] - - audio_features = audio_input["audio_features"] - # (e.g. multiple examples) and the second dim is the multi-audio dim - # (e.g. multiple audios in the same example) - - dtype = next(self.audio_embed.parameters()).dtype - audio_embeds = [ - self.audio_embed( - features.unsqueeze(0).to(dtype), - audio_projection_mode=audio_projection_mode, - ) - for features in audio_features - ] - return audio_embeds - - def _parse_and_validate_image_input( - self, **kwargs: object - ) -> Phi4MMImagePixelInputs | None: - pixel_values = kwargs.get("image_pixel_values") - if pixel_values is None: - return None - - image_sizes = kwargs.get("image_sizes") - image_attention_mask = kwargs.get("image_attention_mask") - num_img_tokens = kwargs.get("num_img_tokens") - assert ( - image_sizes is not None - and image_attention_mask is not None - and num_img_tokens is not None - ), "Missing image inputs" - - return Phi4MMImagePixelInputs( - type="pixel_values", - pixel_values=pixel_values, - image_sizes=image_sizes, - image_attention_mask=image_attention_mask, - num_img_tokens=num_img_tokens, - ) - - def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: - modalities = {} - - # Preserve the order of modalities if there are multiple of them - # from the order of kwargs. - for input_key in kwargs: - if ( - input_key in ("image_pixel_values", "image_embeds") - and "images" not in modalities - ): - modalities["images"] = self._parse_and_validate_image_input(**kwargs) - if ( - input_key in ("audio_input_features", "audio_embeds") - and "audios" not in modalities - ): - modalities["audios"] = self._parse_and_validate_audio_input(**kwargs) - - return modalities - - def _process_image_input( - self, image_input: Phi4MMImagePixelInputs - ) -> list[torch.Tensor]: - if image_input["type"] == "image_embeds": - image_embeds = image_input["image_embeds"].type(self.visual.dtype) - else: - dtype = next(self.image_embed.parameters()).dtype - pixel_values = image_input["pixel_values"].to(dtype) - image_sizes = image_input["image_sizes"] - image_attention_mask = image_input["image_attention_mask"] - image_embeds = self.image_embed( - pixel_values, image_sizes, image_attention_mask - ) - return image_embeds - - def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: - modalities = self._parse_and_validate_multimodal_inputs(**kwargs) - if not modalities: - return [] - - # The result multimodal_embeddings is tuple of tensors, with each - # tensor corresponding to a multimodal data item (image or video). - multimodal_embeddings: tuple[torch.Tensor, ...] = () - - # NOTE: It is important to iterate over the keys in this dictionary - # to preserve the order of the modalities. - audio_projection_mode = "speech" - for modality in modalities: - # make sure process images first - if modality == "images": - audio_projection_mode = "vision" - image_input = modalities["images"] - image_embeddings = self._process_image_input(image_input) - multimodal_embeddings += tuple(image_embeddings) - if modality == "audios": - audio_input = modalities["audios"] - audio_embeddings = self._process_audio_input( - audio_input, audio_projection_mode=audio_projection_mode - ) - multimodal_embeddings += tuple(audio_embeddings) - - return multimodal_embeddings - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs: object, - ) -> torch.Tensor: - if intermediate_tensors is not None: - inputs_embeds = None - - hidden_states = self.language_model( - input_ids, - positions, - intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - return self.language_model.compute_logits(hidden_states) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - - def get_mm_mapping(self) -> MultiModelKeys: - """ - Get the module prefix in multimodal models - """ - return MultiModelKeys.from_string_field( - language_model="language_model.", - connector=[ - "img_projection", - "vision_speech_projection", - "speech_projection", - ], - tower_model=["image_embed", "audio_embed"], - ) - - def get_language_model(self) -> torch.nn.Module: - return self.language_model diff --git a/vllm/model_executor/models/qwen3_omni_moe_thinker.py b/vllm/model_executor/models/qwen3_omni_moe_thinker.py index fe825198dcaa4..e6979211b707f 100755 --- a/vllm/model_executor/models/qwen3_omni_moe_thinker.py +++ b/vllm/model_executor/models/qwen3_omni_moe_thinker.py @@ -494,7 +494,10 @@ class Qwen3Omni_VisionTransformer(nn.Module): cu_seqlens: torch.Tensor, ) -> torch.Tensor: max_seqlen = torch.zeros([], device=cu_seqlens.device) - if self.attn_backend == AttentionBackendEnum.FLASH_ATTN: + if self.attn_backend in { + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + }: max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() return max_seqlen diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index d3b6268e7647b..a4a964bc7c1a6 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -374,7 +374,6 @@ _MULTIMODAL_MODELS = { ), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), - "Phi4MultimodalForCausalLM": ("phi4_multimodal", "Phi4MultimodalForCausalLM"), # noqa: E501 "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 "QwenVLForConditionalGeneration": ("qwen_vl", "QwenVLForConditionalGeneration"), # noqa: E501 "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 @@ -507,6 +506,7 @@ _PREVIOUSLY_SUPPORTED_MODELS = { "MotifForCausalLM": "0.10.2", "Phi3SmallForCausalLM": "0.9.2", "Phi4FlashForCausalLM": "0.10.2", + "Phi4MultimodalForCausalLM": "0.12.0", # encoder-decoder models except whisper # have been removed for V0 deprecation. "BartModel": "0.10.2", diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py index 19052c8d49e44..9f34090e31071 100644 --- a/vllm/model_executor/models/terratorch.py +++ b/vllm/model_executor/models/terratorch.py @@ -64,7 +64,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from .interfaces import IsAttentionFree, MultiModalEmbeddings, SupportsMultiModal -from .interfaces_base import default_pooling_type +from .interfaces_base import attn_type logger = init_logger(__name__) @@ -220,7 +220,7 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor): ) -@default_pooling_type("All") +@attn_type("attention_free") @MULTIMODAL_REGISTRY.register_processor( TerratorchMultiModalProcessor, info=TerratorchProcessingInfo, diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index b93a42ffd24c1..062547401c3cf 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -11,6 +11,7 @@ import pybase64 import torch from vllm.utils.import_utils import PlaceholderModule +from vllm.utils.serial_utils import tensor2base64 from .base import MediaIO @@ -135,8 +136,4 @@ class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]): return torch.load(filepath, weights_only=True) def encode_base64(self, media: torch.Tensor) -> str: - buffer = BytesIO() - torch.save(media, buffer) - buffer.seek(0) - binary_data = buffer.read() - return pybase64.b64encode(binary_data).decode("utf-8") + return tensor2base64(media) diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 2b2c2f9cdc571..a2518d5fd3dc4 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -10,6 +10,7 @@ import sys from dataclasses import dataclass from typing import TYPE_CHECKING +import psutil import regex as re import torch @@ -147,11 +148,21 @@ class CpuPlatform(Platform): from vllm.utils.mem_constants import GiB_bytes kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE + node_dir = "/sys/devices/system/node" if kv_cache_space is None: - kv_cache_space = 4 * GiB_bytes # type: ignore + nodes = ( + [d for d in os.listdir(node_dir) if d.startswith("node")] + if os.path.exists(node_dir) + else [] + ) + num_numa_nodes = len(nodes) or 1 + free_cpu_memory = psutil.virtual_memory().total // num_numa_nodes + DEFAULT_CPU_MEM_UTILIZATION = 0.5 + kv_cache_space = int(free_cpu_memory * DEFAULT_CPU_MEM_UTILIZATION) + kv_cache_space_gib = kv_cache_space / GiB_bytes logger.warning_once( - "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) " - "for CPU backend is not set, using 4 by default." + "VLLM_CPU_KVCACHE_SPACE not set. Using " + f"{kv_cache_space_gib:.2f} GiB for KV cache." ) else: kv_cache_space *= GiB_bytes diff --git a/vllm/reasoning/basic_parsers.py b/vllm/reasoning/basic_parsers.py index 35084c0e7cc86..e78ac4a5ebb37 100644 --- a/vllm/reasoning/basic_parsers.py +++ b/vllm/reasoning/basic_parsers.py @@ -64,8 +64,15 @@ class BaseThinkingReasoningParser(ReasoningParser): ) def is_reasoning_end(self, input_ids: list[int]) -> bool: + start_token_id = self.start_token_id end_token_id = self.end_token_id - return any(input_id == end_token_id for input_id in reversed(input_ids)) + + for i in range(len(input_ids) - 1, -1, -1): + if input_ids[i] == start_token_id: + return False + if input_ids[i] == end_token_id: + return True + return False def extract_content_ids(self, input_ids: list[int]) -> list[int]: """ diff --git a/vllm/tokenizers/__init__.py b/vllm/tokenizers/__init__.py index 42487f5f51651..67a6d7c8eb3d9 100644 --- a/vllm/tokenizers/__init__.py +++ b/vllm/tokenizers/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .deepseekv32 import DeepseekV32Tokenizer from .hf import HfTokenizer from .mistral import MistralTokenizer from .protocol import TokenizerLike @@ -21,4 +22,5 @@ __all__ = [ "get_tokenizer", "cached_tokenizer_from_config", "init_tokenizer_from_config", + "DeepseekV32Tokenizer", ] diff --git a/vllm/tokenizers/deepseek_v32_encoding.py b/vllm/tokenizers/deepseek_v32_encoding.py new file mode 100644 index 0000000000000..521bd92959312 --- /dev/null +++ b/vllm/tokenizers/deepseek_v32_encoding.py @@ -0,0 +1,459 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +# copy from https://huggingface.co/deepseek-ai/DeepSeek-V3.2/blob/main/encoding/encoding_dsv32.py +import copy +import json +from typing import Any + +import regex as re + +# flake8: noqa: E501 +TOOLS_SYSTEM_TEMPLATE = """## Tools +You have access to a set of tools you can use to answer the user's question. +You can invoke functions by writing a "<{dsml_token}function_calls>" block like the following as part of your reply to the user: +<{dsml_token}function_calls> +<{dsml_token}invoke name="$FUNCTION_NAME"> +<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<{dsml_token}invoke name="$FUNCTION_NAME2"> +... + + +String and scalar parameters should be specified as is without any escaping or quotes, while lists and objects should use JSON format. The "string" attribute should be set to "true" for string type parameters and "false" for other types (numbers, booleans, arrays, objects). +If the thinking_mode is enabled, then after function results you should strongly consider outputting a thinking block. Here is an example: +<{dsml_token}function_calls> +... + + +... + +{thinking_start_token}...thinking about results{thinking_end_token} +Here are the functions available in JSONSchema format: + +{tool_schemas} + +""" + +bos_token: str = "<|begin▁of▁sentence|>" +eos_token: str = "<|end▁of▁sentence|>" +thinking_start_token: str = "" +thinking_end_token: str = "" +dsml_token: str = "|DSML|" +system_msg_template: str = "{content}" +user_msg_template: str = "<|User|>{content}<|Assistant|>" +assistant_msg_template: str = "{reasoning}{content}{tool_calls}<|end▁of▁sentence|>" +thinking_template = "{reasoning_content}" + +response_format_template: str = "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}" +tool_call_template: str = ( + '<{dsml_token}invoke name="{name}">\n{arguments}\n' +) +tool_calls_template = ( + "<{dsml_token}function_calls>\n{tool_calls}\n" +) + +tool_output_template: str = "\n{content}" + + +def to_json(value: Any) -> str: + try: + return json.dumps(value, ensure_ascii=False) + except Exception: + return json.dumps(value, ensure_ascii=True) + + +def tools_from_openai_format(tools): + return [tool["function"] for tool in tools] + + +def tool_calls_from_openai_format(tool_calls): + return [ + { + "name": tool_call["function"]["name"], + "arguments": tool_call["function"]["arguments"], + } + for tool_call in tool_calls + ] + + +def tool_calls_to_openai_format(tool_calls): + return [ + { + "type": "function", + "function": { + "name": tool_call["name"], + "arguments": tool_call["arguments"], + }, + } + for tool_call in tool_calls + ] + + +def encode_arguments_to_dsml(tool_call: dict[str, str]) -> str: + p_dsml_template = """<{dsml_token}parameter name="{key}" string="{is_str}">{value}""" + P_dsml_strs = [] + if isinstance(tool_call["arguments"], str): + arguments = json.loads(tool_call["arguments"]) + else: + arguments = tool_call["arguments"] + + for k, v in arguments.items(): + p_dsml_str = p_dsml_template.format( + dsml_token=dsml_token, + key=k, + is_str="true" if isinstance(v, str) else "false", + value=v if isinstance(v, str) else to_json(v), + ) + + P_dsml_strs.append(p_dsml_str) + + return "\n".join(P_dsml_strs) + + +def decode_dsml_to_arguments( + tool_name: str, tool_args: dict[str, tuple[str, str]] +) -> dict[str, str]: + def _decode_value(key: str, value: str, string: str): + if string == "true": + value = to_json(value) + return f"{to_json(key)}: {value}" + + tool_args_json = ( + "{" + + ", ".join( + [_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()] + ) + + "}" + ) + return dict(name=tool_name, arguments=tool_args_json) + + +def render_tools(tools: list[dict[str, str | dict[str, Any]]]) -> str: + tools_json = [to_json(t) for t in tools] + + return TOOLS_SYSTEM_TEMPLATE.format( + tool_schemas="\n".join(tools_json), + dsml_token=dsml_token, + thinking_start_token=thinking_start_token, + thinking_end_token=thinking_end_token, + ) + + +def find_last_user_index(messages: list[dict[str, Any]]) -> int: + last_user_index = -1 + for idx in range(len(messages) - 1, -1, -1): + if messages[idx].get("role") in ["user", "developer"]: + last_user_index = idx + break + return last_user_index + + +def render_message( + index: int, messages: list[dict[str, Any]], thinking_mode: str +) -> str: + assert 0 <= index < len(messages) + assert thinking_mode in ["chat", "thinking"], ( + f"Invalid thinking_mode `{thinking_mode}`" + ) + + prompt = "" + msg = messages[index] + last_user_idx = find_last_user_index(messages) + + role = msg.get("role") + content = msg.get("content") + tools = msg.get("tools") + response_format = msg.get("response_format") + tool_calls = msg.get("tool_calls") + reasoning_content = msg.get("reasoning") or msg.get("reasoning_content") + + if tools: + tools = tools_from_openai_format(tools) + if tool_calls: + tool_calls = tool_calls_from_openai_format(tool_calls) + + if role == "system": + prompt += system_msg_template.format(content=content or "") + if tools: + prompt += "\n\n" + render_tools(tools) + + if response_format: + prompt += "\n\n" + response_format_template.format( + schema=to_json(response_format) + ) + + elif role == "developer": + assert content, f"Invalid message for role `{role}`: {msg}" + content_developer = "" + if tools: + content_developer += "\n\n" + render_tools(tools) + + if response_format: + content_developer += "\n\n" + response_format_template.format( + schema=to_json(response_format) + ) + + content_developer += "\n\n# The user's message is: {}".format(content) + + prompt += user_msg_template.format(content=content_developer) + if index == last_user_idx and thinking_mode == "thinking": + prompt += thinking_start_token + else: + prompt += thinking_end_token + + elif role == "user": + prompt += user_msg_template.format(content=content) + + if index == last_user_idx and thinking_mode == "thinking": + prompt += thinking_start_token + else: + prompt += thinking_end_token + + elif role == "tool": + prev_assistant_idx = index - 1 + assistant_msg = messages[prev_assistant_idx] + while prev_assistant_idx >= 0 and assistant_msg.get("role") == "tool": + prev_assistant_idx -= 1 + assistant_msg = messages[prev_assistant_idx] + + assert ( + index == 0 + or prev_assistant_idx >= 0 + and assistant_msg.get("role") == "assistant" + ), f"Invalid messages at {index}:\n{assistant_msg}" + + tool_call_order = index - prev_assistant_idx + assistant_tool_calls = assistant_msg.get("tool_calls") + assert assistant_tool_calls and len(assistant_tool_calls) >= tool_call_order, ( + "No tool calls but found tool output" + ) + + if tool_call_order == 1: + prompt += "\n\n" + + prompt += tool_output_template.format(content=content) + + if tool_call_order == len(assistant_tool_calls): + prompt += "\n" + + if index >= last_user_idx and thinking_mode == "thinking": + prompt += "\n\n" + thinking_start_token + else: + prompt += "\n\n" + thinking_end_token + + elif role == "assistant": + prev_assistant_idx = index + thinking_part = "" + + tool_calls_content = "" + if tool_calls: + tool_calls = [ + tool_call_template.format( + dsml_token=dsml_token, + name=tool_call.get("name"), + arguments=encode_arguments_to_dsml(tool_call), + ) + for tool_call in tool_calls + ] + tool_calls_content += "\n\n" + tool_calls_template.format( + dsml_token=dsml_token, tool_calls="\n".join(tool_calls) + ) + + summary_content = content or "" + + if thinking_mode == "thinking" and index > last_user_idx: + assert reasoning_content or tool_calls, ( + f"ThinkingMode: {thinking_mode}, invalid message without reasoning_content/tool_calls `{msg}` after last user message" + ) + thinking_part = ( + thinking_template.format(reasoning_content=reasoning_content or "") + + thinking_end_token + ) + + prompt += assistant_msg_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tool_calls_content, + ) + else: + raise NotImplementedError(f"Unknown role: {role}") + + return prompt + + +def drop_thinking_messages( + messages: list[dict[str, Any]], last_user_idx: int | None = None +) -> list[dict[str, Any]]: + messages_wo_thinking: list[dict[str, Any]] = [] + last_user_idx = ( + find_last_user_index(messages) if last_user_idx is None else last_user_idx + ) + for idx, msg in enumerate(messages): + role = msg.get("role") + if role in ["user", "system", "tool"] or idx >= last_user_idx: + messages_wo_thinking.append(msg) + continue + + elif role == "assistant": + msg_wo_thinking = copy.copy(msg) + msg_wo_thinking.pop("reasoning_content", None) + msg_wo_thinking.pop("reasoning", None) + messages_wo_thinking.append(msg_wo_thinking) + + return messages_wo_thinking + + +def encode_messages( + messages: list[dict[str, Any]], + thinking_mode: str, + context: list[dict[str, Any]] | None = None, + drop_thinking: bool = True, + add_default_bos_token: bool = True, +) -> str: + context = context if context else [] + full_messages = context + messages + + prompt = bos_token if add_default_bos_token and len(context) == 0 else "" + + if thinking_mode == "thinking" and drop_thinking: + full_messages = drop_thinking_messages(full_messages) + + for idx in range(len(messages)): + prompt += render_message( + idx + len(context), full_messages, thinking_mode=thinking_mode + ) + + return prompt + + +def _read_until_stop( + index: int, text: str, stop: list[str] +) -> tuple[int, str, None | str]: + min_pos = len(text) + matched_stop = None + + for s in stop: + pos = text.find(s, index) + if pos != -1 and pos < min_pos: + min_pos = pos + matched_stop = s + + if matched_stop: + content = text[index:min_pos] + return min_pos + len(matched_stop), content, matched_stop + else: + content = text[index:] + return len(text), content, None + + +def parse_tool_calls(index: int, text: str): + tool_calls: list[dict[str, Any]] = [] + stop_token = None + tool_calls_end_token = f"" + + while index < len(text): + index, _, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}invoke", tool_calls_end_token] + ) + assert _ == ">\n", "Tool call format error" + + if stop_token == tool_calls_end_token: + break + + assert stop_token is not None, "Missing special token" + + index, tool_name_content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n$', tool_name_content, flags=re.DOTALL + ) + assert len(p_tool_name) == 1, "Tool name format error" + tool_name = p_tool_name[0] + + tool_args: dict[str, tuple[str, str]] = {} + while stop_token == f"<{dsml_token}parameter": + index, param_content, stop_token = _read_until_stop( + index, text, [f"/{dsml_token}parameter"] + ) + + param_kv = re.findall( + r'^ name="(.*?)" string="(true|false)">(.*?)<$', + param_content, + flags=re.DOTALL, + ) + assert len(param_kv) == 1, "Parameter format error" + param_name, string, param_value = param_kv[0] + + assert param_name not in tool_args, "Duplicate parameter name" + tool_args[param_name] = (param_value, string) + + index, content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n", "Parameter format error" + + tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args) + tool_calls.append(tool_call) + + return index, stop_token, tool_calls + + +# NOTE: This function is designed to parse only correctly +# formatted string and will not attempt to correct malformed output +# that may be generated by the model. +def parse_message_from_completion_text(text: str, thinking_mode: str): + summary_content, reasoning_content, tool_calls = "", "", [] + index, stop_token = 0, None + tool_calls_start_token = f"\n\n<{dsml_token}function_calls" + + is_thinking, is_tool_calling = thinking_mode == "thinking", False + + if is_thinking: + index, content_delta, stop_token = _read_until_stop( + index, text, [thinking_end_token, tool_calls_start_token] + ) + reasoning_content = content_delta + assert stop_token == thinking_end_token, "Invalid thinking format" + + index, content_delta, stop_token = _read_until_stop( + index, text, [eos_token, tool_calls_start_token] + ) + summary_content = content_delta + if stop_token == tool_calls_start_token: + is_tool_calling = True + else: + assert stop_token == eos_token, "Invalid summary format" + + if is_tool_calling: + index, stop_token, tool_calls = parse_tool_calls(index, text) + + index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token]) + assert not tool_ends_text, "Unexpected content after tool calls" + + assert len(text) == index and stop_token in [eos_token, None], ( + "Unexpected content at end" + ) + + for sp_token in [ + bos_token, + eos_token, + thinking_start_token, + thinking_end_token, + dsml_token, + ]: + assert sp_token not in summary_content and sp_token not in reasoning_content, ( + "Unexpected special token in content" + ) + + return { + "role": "assistant", + "content": summary_content, + "reasoning_content": reasoning_content, + "reasoning": reasoning_content, + "tool_calls": tool_calls_to_openai_format(tool_calls), + } diff --git a/vllm/tokenizers/deepseekv32.py b/vllm/tokenizers/deepseekv32.py new file mode 100644 index 0000000000000..1140357cf861d --- /dev/null +++ b/vllm/tokenizers/deepseekv32.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from pathlib import Path + +from transformers import BatchEncoding + +from .deepseek_v32_encoding import encode_messages +from .hf import HfTokenizer, TokenizerLike +from .registry import TokenizerRegistry + + +@TokenizerRegistry.register("deepseek_v32") +class DeepseekV32Tokenizer(HfTokenizer): + def __init__(self, tokenizer: TokenizerLike): + self.tokenizer = tokenizer + self.name_or_path = ( + tokenizer.name_or_path if hasattr(tokenizer, "name_or_path") else "" + ) + + @classmethod + def from_pretrained( + cls, + path_or_repo_id: str | Path, + *args, + trust_remote_code: bool = False, + revision: str | None = None, + download_dir: str | None = None, + **kwargs, + ) -> "TokenizerLike": + tokenizer = super().from_pretrained( + path_or_repo_id, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + download_dir=download_dir, + **kwargs, + ) + return DeepseekV32Tokenizer(tokenizer) + + def apply_chat_template(self, messages, tools=None, **kwargs): + thinking = kwargs.get("thinking", False) + thinking_mode = "thinking" + if not thinking: + thinking_mode = "chat" + conversation = kwargs.get("conversation", messages) + messages = conversation.copy() + drop_thinking = True + if tools is not None and len(tools) > 0: + messages.insert(0, {"role": "system"}) + messages[0]["tools"] = tools + drop_thinking = False + encode_config = dict(thinking_mode=thinking_mode, drop_thinking=drop_thinking) + prompt_str = encode_messages(messages, **encode_config) # type: ignore + return prompt_str + + @property + def all_special_tokens(self) -> list[str]: + return self.tokenizer.all_special_tokens + + @property + def all_special_ids(self) -> list[int]: + return self.tokenizer.all_special_ids + + @property + def bos_token_id(self) -> int: + return self.tokenizer.bos_token_id + + @property + def eos_token_id(self) -> int: + return self.tokenizer.eos_token_id + + @property + def pad_token_id(self) -> int: + return self.tokenizer.pad_token_id + + @property + def is_fast(self) -> bool: + return self.tokenizer.is_fast + + @property + def vocab_size(self) -> int: + return self.tokenizer.vocab_size + + @property + def max_token_id(self) -> int: + return self.tokenizer.max_token_id + + @property + def truncation_side(self) -> str: + return self.tokenizer.truncation_side + + def __hash__(self) -> int: + return hash(id(self)) + + def __len__(self) -> int: + # is an added token in DeepseekV32 tokenizer + return self.vocab_size + len(self.get_added_vocab()) + + def __call__( + self, + text: str | list[str], + text_pair: str | None = None, + add_special_tokens: bool = True, + truncation: bool = False, + max_length: int | None = None, + ) -> "BatchEncoding": + return self.tokenizer( + text, + text_pair=text_pair, + add_special_tokens=add_special_tokens, + truncation=truncation, + max_length=max_length, + ) + + def get_vocab(self) -> dict[str, int]: + return self.tokenizer.get_vocab() + + def get_added_vocab(self) -> dict[str, int]: + return self.tokenizer.get_added_vocab() + + def encode( + self, + text: str, + truncation: bool | None = None, + max_length: int | None = None, + add_special_tokens: bool = True, + ) -> list[int]: + return self.tokenizer.encode( + text, + truncation=truncation, + max_length=max_length, + add_special_tokens=add_special_tokens, + ) + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + return self.tokenizer.convert_tokens_to_string(tokens) + + def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str: + return self.tokenizer.decode(ids, skip_special_tokens=skip_special_tokens) + + def convert_ids_to_tokens( + self, + ids: list[int], + skip_special_tokens: bool = False, + ) -> list[str]: + return self.tokenizer.convert_ids_to_tokens( + ids, skip_special_tokens=skip_special_tokens + ) diff --git a/vllm/transformers_utils/runai_utils.py b/vllm/transformers_utils/runai_utils.py index eac4294bb59cd..041056720a96b 100644 --- a/vllm/transformers_utils/runai_utils.py +++ b/vllm/transformers_utils/runai_utils.py @@ -18,9 +18,7 @@ SUPPORTED_SCHEMES = ["s3://", "gs://"] try: from runai_model_streamer import list_safetensors as runai_list_safetensors from runai_model_streamer import pull_files as runai_pull_files -except (ImportError, OSError): - # see https://github.com/run-ai/runai-model-streamer/issues/26 - # OSError will be raised on arm64 platform +except ImportError: runai_model_streamer = PlaceholderModule("runai_model_streamer") # type: ignore[assignment] runai_pull_files = runai_model_streamer.placeholder_attr("pull_files") runai_list_safetensors = runai_model_streamer.placeholder_attr("list_safetensors") diff --git a/vllm/utils/hashing.py b/vllm/utils/hashing.py index edf1e9cb34e56..f01c6b074ffeb 100644 --- a/vllm/utils/hashing.py +++ b/vllm/utils/hashing.py @@ -11,6 +11,17 @@ from typing import Any import cbor2 +try: + # It is important that this remains an optional dependency. + # It would not be allowed in environments with strict security controls, + # so it's best not to have it installed when not in use. + import xxhash as _xxhash + + if not hasattr(_xxhash, "xxh3_128_digest"): + _xxhash = None +except ImportError: # pragma: no cover + _xxhash = None + def sha256(input: Any) -> bytes: """Hash any picklable Python object using SHA-256. @@ -47,6 +58,27 @@ def sha256_cbor(input: Any) -> bytes: return hashlib.sha256(input_bytes).digest() +def _xxhash_digest(input_bytes: bytes) -> bytes: + if _xxhash is None: + raise ModuleNotFoundError( + "xxhash is required for the 'xxhash' prefix caching hash algorithms. " + "Install it via `pip install xxhash`." + ) + return _xxhash.xxh3_128_digest(input_bytes) + + +def xxhash(input: Any) -> bytes: + """Hash picklable objects using xxHash.""" + input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) + return _xxhash_digest(input_bytes) + + +def xxhash_cbor(input: Any) -> bytes: + """Hash objects serialized with CBOR using xxHash.""" + input_bytes = cbor2.dumps(input, canonical=True) + return _xxhash_digest(input_bytes) + + def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: """Get a hash function by name, or raise an error if the function is not found. @@ -60,6 +92,10 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], bytes]: return sha256 if hash_fn_name == "sha256_cbor": return sha256_cbor + if hash_fn_name == "xxhash": + return xxhash + if hash_fn_name == "xxhash_cbor": + return xxhash_cbor raise ValueError(f"Unsupported hash function: {hash_fn_name}") diff --git a/vllm/utils/serial_utils.py b/vllm/utils/serial_utils.py index b89fa6ce4db66..a6d717e03d37d 100644 --- a/vllm/utils/serial_utils.py +++ b/vllm/utils/serial_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import base64 +import io import sys from dataclasses import dataclass from typing import Literal @@ -52,6 +53,15 @@ Endianness = Literal["native", "big", "little"] EncodingFormat = Literal["float", "base64", "bytes"] +def tensor2base64(x: torch.Tensor) -> str: + with io.BytesIO() as buf: + torch.save(x, buf) + buf.seek(0) + binary_data = buf.read() + + return base64.b64encode(binary_data).decode("utf-8") + + def tensor2binary( tensor: torch.Tensor, embed_dtype: EmbedDType, endianness: Endianness ) -> bytes: diff --git a/vllm/utils/system_utils.py b/vllm/utils/system_utils.py index a4eb8f4d4fd7d..76cac59c18098 100644 --- a/vllm/utils/system_utils.py +++ b/vllm/utils/system_utils.py @@ -204,6 +204,10 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: def decorate_logs(process_name: str | None = None) -> None: """Decorate stdout/stderr with process name and PID prefix.""" + # Respect VLLM_CONFIGURE_LOGGING environment variable + if not envs.VLLM_CONFIGURE_LOGGING: + return + if process_name is None: process_name = get_mp_context().current_process().name diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index fe92f6570501c..a2a6eeeb16b24 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -31,6 +31,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.batch_invariant import ( vllm_is_batch_invariant, ) +from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.v1.attention.backends.utils import ( @@ -927,7 +928,18 @@ def get_kernel_options( if torch.cuda.is_available(): device_props = torch.cuda.get_device_properties() - max_shared_memory = device_props.shared_memory_per_block_optin + # ROCm doesn't expose shared_memory_per_block_optin attribute + # AMD GPUs typically have 64KB LDS (Local Data Share) per workgroup + if hasattr(device_props, "shared_memory_per_block_optin"): + max_shared_memory = device_props.shared_memory_per_block_optin + elif current_platform.is_rocm(): + # ROCm fallback: use 64KB + max_shared_memory = 65536 + else: + raise RuntimeError( + "Unable to determine shared memory size on this hardware." + ) + if max_shared_memory < 144 * 1024: block_m_candidate = ensure_divisible( max(1, block_m_candidate // 2), block_m diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 602eb81beb010..774200deed158 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -12,7 +12,7 @@ from typing import Any, NewType, TypeAlias, overload from vllm import envs from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils.hashing import sha256_cbor +from vllm.utils.hashing import sha256_cbor, xxhash_cbor from vllm.utils.math_utils import cdiv from vllm.utils.mem_constants import GiB_bytes from vllm.v1.kv_cache_interface import ( @@ -83,18 +83,19 @@ logger = init_logger(__name__) # # The function `init_none_hash` initializes this variable globally. NONE_HASH: BlockHash +_CBOR_HASH_FUNCTIONS = frozenset({sha256_cbor, xxhash_cbor}) def init_none_hash(hash_fn: Callable[[Any], bytes]): global NONE_HASH hash_seed = os.getenv("PYTHONHASHSEED") - if hash_seed is None and hash_fn is sha256_cbor: + if hash_seed is None and hash_fn in _CBOR_HASH_FUNCTIONS: logger.warning( "PYTHONHASHSEED is not set. This will lead to non-reproducible " - "block-hashes when using sha256_cbor as the hash function." - "Consider setting PYTHONHASHSEED to a fixed value for " - "reproducibility." + "block-hashes when using CBOR-based hash functions such as " + "sha256_cbor or xxhash_cbor. Consider setting PYTHONHASHSEED to a " + "fixed value for reproducibility." ) if hash_seed is None: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index d21cdf04ead26..8772f2e488dc0 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -409,8 +409,6 @@ class LLMEngine: return self.collective_rpc("apply_model", args=(func,)) def __del__(self): - if ( - dp_group := getattr(self, "dp_group", None) - and not self.external_launcher_dp - ): + dp_group = getattr(self, "dp_group", None) + if dp_group is not None and not self.external_launcher_dp: stateless_destroy_torch_distributed_process_group(dp_group) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 88ac6b4aeb4bb..546eacebf83e5 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -89,7 +89,7 @@ class LogprobsTensors(NamedTuple): # [num_reqs, ] # The shape of each element depends on the pooler used -PoolerOutput = torch.Tensor | list[torch.Tensor] +PoolerOutput = list[torch.Tensor | None] | torch.Tensor | None @dataclass diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 7bd2c7415dafe..acd1a00e87553 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import torch from vllm.pooling_params import PoolingParams +from vllm.tasks import PoolingTask from vllm.utils.platform_utils import is_pin_memory_available pin_memory = is_pin_memory_available() @@ -16,6 +17,7 @@ class PoolingCursor: first_token_indices_gpu: torch.Tensor last_token_indices_gpu: torch.Tensor prompt_lens_cpu: torch.Tensor + seq_lens_cpu: torch.Tensor num_scheduled_tokens_cpu: torch.Tensor def __getitem__(self, indices: slice): @@ -24,12 +26,25 @@ class PoolingCursor: first_token_indices_gpu=self.first_token_indices_gpu[indices], last_token_indices_gpu=self.last_token_indices_gpu[indices], prompt_lens_cpu=self.prompt_lens_cpu[indices], + seq_lens_cpu=self.seq_lens_cpu[indices], num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices], ) def is_partial_prefill(self): return not torch.all(self.prompt_lens_cpu == self.num_scheduled_tokens_cpu) + def is_finished(self): + return self.prompt_lens_cpu == self.seq_lens_cpu + + +class PoolingStates: + def __init__(self): + # for chunked prefill with ALL pooling + self.hidden_states_cache: list[torch.Tensor] = [] + + def clean(self): + self.hidden_states_cache.clear() + @dataclass class PoolingMetadata: @@ -38,8 +53,21 @@ class PoolingMetadata: prompt_lens: torch.Tensor # CPU Tensor prompt_token_ids: torch.Tensor | None pooling_params: list[PoolingParams] + pooling_states: list[PoolingStates] pooling_cursor: PoolingCursor | None = None + def __post_init__(self) -> None: + pooling_params = self.pooling_params + + tasks: list[PoolingTask] = [ + task + for pooling_param in pooling_params + if (task := pooling_param.task) is not None + ] + assert len(pooling_params) == len(tasks) + + self.tasks = tasks + def __getitem__(self, indices: slice): return PoolingMetadata( prompt_lens=self.prompt_lens[indices], @@ -47,21 +75,36 @@ class PoolingMetadata: if self.prompt_token_ids is None else self.prompt_token_ids[indices], pooling_params=self.pooling_params[indices], + pooling_states=self.pooling_states[indices], pooling_cursor=None if self.pooling_cursor is None else self.pooling_cursor[indices], ) + def get_prompt_token_ids(self) -> list[torch.Tensor]: + prompt_token_ids = self.prompt_token_ids + assert prompt_token_ids is not None, ( + "Please set `requires_token_ids=True` in `get_pooling_updates`" + ) + + return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)] + def build_pooling_cursor( - self, num_scheduled_tokens: list[int], device: torch.device + self, + num_scheduled_tokens: list[int], + seq_lens_cpu: torch.Tensor, + device: torch.device, ): self.pooling_cursor = build_pooling_cursor( - num_scheduled_tokens, self.prompt_lens, device + num_scheduled_tokens, seq_lens_cpu, self.prompt_lens, device ) def build_pooling_cursor( - num_scheduled_tokens: list[int], prompt_lens: torch.Tensor, device: torch.device + num_scheduled_tokens: list[int], + seq_lens_cpu: torch.Tensor, + prompt_lens: torch.Tensor, + device: torch.device, ): assert len(prompt_lens) == len(num_scheduled_tokens) @@ -78,5 +121,6 @@ def build_pooling_cursor( first_token_indices_gpu=cumsum[:n_seq], last_token_indices_gpu=cumsum[1:] - 1, prompt_lens_cpu=prompt_lens, + seq_lens_cpu=seq_lens_cpu, num_scheduled_tokens_cpu=num_scheduled_tokens_cpu, ) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 516c76a5e4b15..ead7a3619dea5 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -15,7 +15,7 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.utils.collection_utils import swap_dict_values from vllm.v1.outputs import LogprobsTensors -from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates from vllm.v1.sample.logits_processor import ( BatchUpdateBuilder, LogitsProcessors, @@ -33,7 +33,6 @@ class CachedRequestState: prompt_token_ids: list[int] | None mm_features: list[MultiModalFeatureSpec] sampling_params: SamplingParams | None - pooling_params: PoolingParams | None generator: torch.Generator | None block_ids: tuple[list[int], ...] @@ -51,11 +50,18 @@ class CachedRequestState: # Used when both async_scheduling and spec_decode are enabled. prev_num_draft_len: int = 0 + # for pooling models + pooling_params: PoolingParams | None = None + pooling_states: PoolingStates | None = None + def __post_init__(self): self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( self.prompt_token_ids, self.prompt_embeds ) + if self.pooling_params is not None: + self.pooling_states = PoolingStates() + @property def num_tokens(self) -> int: return self.num_prompt_tokens + len(self.output_token_ids) @@ -255,7 +261,9 @@ class InputBatch: # This is updated each time the batch constituents change. self.sampling_metadata = self._make_sampling_metadata() + # for pooling models self.pooling_params: dict[str, PoolingParams] = {} + self.pooling_states: dict[str, PoolingStates] = {} # Cached reference to the GPU tensor of previously sampled tokens self.prev_sampled_token_ids: torch.Tensor | None = None @@ -413,7 +421,11 @@ class InputBatch: sampling_params.bad_words_token_ids ) elif pooling_params := request.pooling_params: + pooling_states = request.pooling_states + assert pooling_states is not None + self.pooling_params[req_id] = pooling_params + self.pooling_states[req_id] = pooling_states self.logits_processing_needs_token_ids[req_index] = ( pooling_params.requires_token_ids ) @@ -469,6 +481,7 @@ class InputBatch: if self.is_pooling_model: self.pooling_params.pop(req_id, None) + self.pooling_states.pop(req_id, None) return req_index self.greedy_reqs.discard(req_id) @@ -837,13 +850,19 @@ class InputBatch: assert len(self.req_ids) == len(self.pooling_params) return [self.pooling_params[req_id] for req_id in self.req_ids] + def get_pooling_states(self) -> list[PoolingStates]: + assert len(self.req_ids) == len(self.pooling_states) + return [self.pooling_states[req_id] for req_id in self.req_ids] + def get_pooling_metadata(self) -> PoolingMetadata: pooling_params = self.get_pooling_params() + pooling_states = self.get_pooling_states() return PoolingMetadata( prompt_lens=torch.from_numpy(self.num_prompt_tokens[: self.num_reqs]), prompt_token_ids=self.sampling_metadata.prompt_token_ids, pooling_params=pooling_params, + pooling_states=pooling_states, ) def _make_prompt_token_ids_tensor(self) -> torch.Tensor: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ea21bb16f6c74..9b09cd6df45fe 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -132,7 +132,7 @@ from vllm.v1.outputs import ( SamplerOutput, make_empty_encoder_model_runner_output, ) -from vllm.v1.pool.metadata import PoolingMetadata +from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.logits_processor.interface import LogitsProcessor from vllm.v1.sample.metadata import SamplingMetadata @@ -2401,20 +2401,6 @@ class GPUModelRunner( supported_tasks = list(model.pooler.get_supported_tasks()) - if self.scheduler_config.enable_chunked_prefill: - if "token_embed" in supported_tasks: - supported_tasks.remove("token_embed") - if "token_classify" in supported_tasks: - supported_tasks.remove("token_classify") - - logger.debug_once( - "Chunked prefill is not supported with " - "token_embed and token_classify tasks " - "which using ALL pooling. " - "Please turn off chunked prefill by " - "`--no-enable-chunked-prefill` before using it." - ) - if "score" in supported_tasks: num_labels = getattr(self.model_config.hf_config, "num_labels", 0) if num_labels != 1: @@ -2491,11 +2477,12 @@ class GPUModelRunner( ) hidden_states = hidden_states[:num_scheduled_tokens] + seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs] + pooling_metadata = self.input_batch.get_pooling_metadata() pooling_metadata.build_pooling_cursor( - num_scheduled_tokens_np.tolist(), device=hidden_states.device + num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device ) - seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs] model = cast(VllmModelForPooling, self.model) raw_pooler_output: PoolerOutput = model.pooler( @@ -2503,7 +2490,7 @@ class GPUModelRunner( pooling_metadata=pooling_metadata, ) raw_pooler_output = json_map_leaves( - lambda x: x.to("cpu", non_blocking=True), + lambda x: x.to("cpu", non_blocking=True) if x is not None else x, raw_pooler_output, ) self._sync_device() @@ -4358,10 +4345,13 @@ class GPUModelRunner( prompt_lens=dummy_prompt_lens, prompt_token_ids=dummy_token_ids, pooling_params=[dummy_pooling_params] * num_reqs, + pooling_states=[PoolingStates() for i in range(num_reqs)], ) dummy_metadata.build_pooling_cursor( - num_scheduled_tokens_list, device=hidden_states.device + num_scheduled_tokens_list, + seq_lens_cpu=dummy_prompt_lens, + device=hidden_states.device, ) try: @@ -4388,22 +4378,12 @@ class GPUModelRunner( supported_pooling_tasks = self.get_supported_pooling_tasks() if not supported_pooling_tasks: - if self.scheduler_config.enable_chunked_prefill: - raise RuntimeError( - f"Model {self.model_config.model} does not support " - "any pooling tasks with chunked prefill enabled. " - "Please add --no-enable-chunked-prefill to your " - "config or CLI args. See " - "https://docs.vllm.ai/en/latest/models/pooling_models.html " - "to learn more." - ) - else: - raise RuntimeError( - f"Model {self.model_config.model} does not support " - "any pooling tasks. See " - "https://docs.vllm.ai/en/latest/models/pooling_models.html " - "to learn more." - ) + raise RuntimeError( + f"Model {self.model_config.model} does not support " + "any pooling tasks. See " + "https://docs.vllm.ai/en/latest/models/pooling_models.html " + "to learn more." + ) output_size = dict[PoolingTask, float]() for task in supported_pooling_tasks: