Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-08-25 08:54:32 -07:00
commit 65f93694be
60 changed files with 3105 additions and 1153 deletions

View File

@ -18,14 +18,15 @@ Easy, fast, and cheap LLM serving for everyone
*Latest News* 🔥
- [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH).
- [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152).
- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
- [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/).
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
<details>
<summary>Previous News</summary>
- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).

View File

@ -59,6 +59,12 @@ become available.
<td style="text-align: center;"></td>
<td><code>synthetic</code></td>
</tr>
<tr>
<td><strong>RandomMultiModal (Image/Video)</strong></td>
<td style="text-align: center;">🟡</td>
<td style="text-align: center;">🚧</td>
<td><code>synthetic</code> </td>
</tr>
<tr>
<td><strong>Prefix Repetition</strong></td>
<td style="text-align: center;"></td>
@ -722,4 +728,75 @@ python benchmarks/benchmark_serving.py \
--endpoint /v1/chat/completion
```
### Synthetic Random Images (random-mm)
Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets.
Notes:
- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`.
- Video sampling is not yet implemented.
Start the server (example):
```bash
vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
--dtype bfloat16 \
--max-model-len 16384 \
--limit-mm-per-prompt '{"image": 3, "video": 0}' \
--mm-processor-kwargs max_pixels=1003520
```
Benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via the arg `random-output-len`.
Ex.1: Fixed number of items and a single image resolutionm, enforcing generation of approx 40 tokens:
```bash
vllm bench serve \
--backend openai-chat \
--model Qwen/Qwen2.5-VL-3B-Instruct \
--endpoint /v1/chat/completions \
--dataset-name random-mm \
--num-prompts 100 \
--max-concurrency 10 \
--random-prefix-len 25 \
--random-input-len 300 \
--random-output-len 40 \
--random-range-ratio 0.2 \
--random-mm-base-items-per-request 2 \
--random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \
--random-mm-bucket-config '{(224, 224, 1): 1.0}' \
--request-rate inf \
--ignore-eos \
--seed 42
```
The number of items per request can be controlled by passing multiple image buckets:
```bash
--random-mm-base-items-per-request 2 \
--random-mm-num-mm-items-range-ratio 0.5 \
--random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \
--random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \
```
Flags specific to `random-mm`:
- `--random-mm-base-items-per-request`: base number of multimodal items per request.
- `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(n·(1r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items.
- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'.
- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 for videos (video sampling not yet supported).
Behavioral notes:
- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping.
How sampling works:
- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits.
- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added.
- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing.
This should be seen as an edge case, and if this behavior can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number. Note that this might result in errors due to engine config `--limit-mm-per-prompt`.
- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`.
</details>

View File

@ -2,6 +2,7 @@
We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
- [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg), August 23rd 2025. [[Slides]](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH)
- [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA), August 2nd 2025. [[Slides]](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) [[Recording]](https://www.chaspark.com/#/live/1166916873711665152).
- [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing)
- [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).

View File

@ -196,6 +196,13 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2
!!! note
API server scale-out is only available for online inference.
!!! warning
By default, 8 CPU threads are used in each API server to load media items (e.g. images)
from request data.
If you apply API server scale-out, consider adjusting `VLLM_MEDIA_LOADING_THREAD_COUNT`
to avoid CPU resource exhaustion.
!!! note
[Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled
because it requires a one-to-one correspondance between API and engine core processes.

View File

@ -18,7 +18,7 @@ prometheus_client >= 0.18.0
pillow # Required for image processing
prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.11, < 0.11
lm-format-enforcer == 0.11.3
llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
outlines_core == 0.2.10 ; platform_machine != "s390x"
outlines == 0.1.11 ; platform_machine == "s390x"

View File

@ -0,0 +1,344 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
from typing import Any, NamedTuple, Optional, cast
import numpy as np
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset,
SampleRequest)
@pytest.fixture(scope="session")
def hf_tokenizer() -> PreTrainedTokenizerBase:
# Use a small, commonly available tokenizer
return AutoTokenizer.from_pretrained("gpt2")
class Params(NamedTuple):
num_requests: int
prefix_len: int
range_ratio: float
input_len: int
output_len: int
@pytest.fixture(scope="session")
def random_dataset_params() -> Params:
return Params(num_requests=16,
prefix_len=7,
range_ratio=0.3,
input_len=50,
output_len=20)
def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
"""Project a SampleRequest into a comparable tuple."""
return (req.prompt, req.prompt_len, req.expected_output_len)
def _collect_samples(dataset: RandomDataset,
tokenizer: PreTrainedTokenizerBase,
num_requests: int = 16,
prefix_len: int = 7,
range_ratio: float = 0.3,
input_len: int = 50,
output_len: int = 20) -> list[tuple[str, int, int]]:
samples = dataset.sample(
tokenizer=tokenizer,
num_requests=num_requests,
prefix_len=prefix_len,
range_ratio=range_ratio,
input_len=input_len,
output_len=output_len,
)
return [_fingerprint_sample(s) for s in samples]
@pytest.mark.benchmark
def test_random_dataset_same_seed(
hf_tokenizer: PreTrainedTokenizerBase,
random_dataset_params: Params) -> None:
"""Same seed should yield identical outputs, even if global RNGs change.
This guards against accidental reliance on Python's random or np.random
in RandomDataset after moving to numpy.default_rng.
"""
p = random_dataset_params
common_seed = 123
dataset_a = RandomDataset(random_seed=common_seed)
dataset_b = RandomDataset(random_seed=common_seed)
a = _collect_samples(dataset_a,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len)
# Perturb global RNG state to ensure isolation
random.seed(999)
_ = [random.random() for _ in range(100)]
np.random.seed(888)
_ = [np.random.random() for _ in range(100)]
b = _collect_samples(dataset_b,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len)
assert a == b
@pytest.mark.benchmark
def test_random_dataset_different_seeds(
hf_tokenizer: PreTrainedTokenizerBase,
random_dataset_params: Params) -> None:
"""Different seeds should change outputs with overwhelming likelihood."""
p = random_dataset_params
seed_a = 0
dataset_a = RandomDataset(random_seed=seed_a)
a = _collect_samples(dataset_a,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len)
seed_b = 999
dataset_b = RandomDataset(random_seed=seed_b)
# Perturb global RNG with same seed as dataset_a to ensure isolation
random.seed(seed_a)
np.random.seed(seed_a)
b = _collect_samples(dataset_b,
hf_tokenizer,
num_requests=p.num_requests,
prefix_len=p.prefix_len,
range_ratio=p.range_ratio,
input_len=p.input_len,
output_len=p.output_len)
assert a != b
# -----------------------------
# RandomMultiModalDataset tests
# -----------------------------
def _mm_fingerprint_sample(
req: SampleRequest,
) -> tuple[str, int, int, int, list[str]]:
"""Create a compact fingerprint for multimodal samples.
Includes:
- prompt string
- prompt_len
- expected_output_len
- count of multimodal items
- per-item type and URL prefix (e.g., 'data:image/jpeg;base64,')
"""
items = req.multi_modal_data or []
item_prefixes: list[str] = []
for it in items:
if isinstance(it, dict) and it.get("type") == "image_url":
url = it.get("image_url", {}).get("url", "")
# Only keep a short identifying prefix to avoid huge strings
item_prefixes.append(f"image:{url[:22]}")
elif isinstance(it, dict) and it.get("type") == "video_url":
url = it.get("video_url", {}).get("url", "")
item_prefixes.append(f"video:{url[:22]}")
else:
item_prefixes.append("unknown:")
return (req.prompt, req.prompt_len, req.expected_output_len, len(items),
item_prefixes)
def _collect_mm_samples(
dataset: RandomMultiModalDataset,
tokenizer: PreTrainedTokenizerBase,
*,
num_requests: int = 8,
prefix_len: int = 3,
range_ratio: float = 0.0,
input_len: int = 20,
output_len: int = 5,
base_items_per_request: int = 2,
num_mm_items_range_ratio: float = 0.0,
limit_mm_per_prompt: Optional[dict[str, int]] = None,
bucket_config: Optional[dict[tuple[int, int, int], float]] = None,
enable_multimodal_chat: bool = False,
) -> list[SampleRequest]:
if limit_mm_per_prompt is None:
limit_mm_per_prompt = {"image": 5, "video": 0}
if bucket_config is None:
bucket_config = {(32, 32, 1): 0.5, (52, 64, 1): 0.5}
return dataset.sample(
tokenizer=tokenizer,
num_requests=num_requests,
prefix_len=prefix_len,
range_ratio=range_ratio,
input_len=input_len,
output_len=output_len,
base_items_per_request=base_items_per_request,
num_mm_items_range_ratio=num_mm_items_range_ratio,
limit_mm_per_prompt=limit_mm_per_prompt,
bucket_config=bucket_config,
enable_multimodal_chat=enable_multimodal_chat,
)
@pytest.mark.benchmark
def test_random_mm_same_seed(hf_tokenizer: PreTrainedTokenizerBase) -> None:
seed = 42
ds_a = RandomMultiModalDataset(random_seed=seed)
ds_b = RandomMultiModalDataset(random_seed=seed)
a = _collect_mm_samples(ds_a, hf_tokenizer)
b = _collect_mm_samples(ds_b, hf_tokenizer)
fa = [_mm_fingerprint_sample(s) for s in a]
fb = [_mm_fingerprint_sample(s) for s in b]
assert fa == fb
@pytest.mark.benchmark
def test_random_mm_different_seeds(
hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
ds_a = RandomMultiModalDataset(random_seed=0)
ds_b = RandomMultiModalDataset(random_seed=999)
a = _collect_mm_samples(ds_a, hf_tokenizer)
b = _collect_mm_samples(ds_b, hf_tokenizer)
fa = [_mm_fingerprint_sample(s) for s in a]
fb = [_mm_fingerprint_sample(s) for s in b]
assert fa != fb
@pytest.mark.benchmark
def test_random_mm_respects_limits(
hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
ds = RandomMultiModalDataset(random_seed=0)
# Requesting 3 items with a per-prompt limit of 1 should error per current
# design (dataset refuses to silently clamp below the requested baseline).
with pytest.raises(ValueError):
_collect_mm_samples(
ds,
hf_tokenizer,
num_requests=12,
base_items_per_request=3,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt={"image": 1, "video": 0},
bucket_config={(32, 32, 1): 1.0},
)
@pytest.mark.benchmark
def test_random_mm_zero_prob_entries_are_removed(
hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
ds = RandomMultiModalDataset(random_seed=0)
# Second bucket has zero probability and should be ignored after
# normalization
samples = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=6,
base_items_per_request=2,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt={"image": 10, "video": 0},
bucket_config={(32, 32, 1): 1.0, (52, 64, 1): 0.0},
)
for s in samples:
assert isinstance(s.multi_modal_data, list)
typed_mm = cast(list[dict[str, Any]], s.multi_modal_data)
for it in typed_mm:
assert it.get("type") == "image_url"
@pytest.mark.benchmark
def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
ds = RandomMultiModalDataset(random_seed=0)
samples = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=5,
base_items_per_request=0,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt={"image": 5, "video": 0},
bucket_config={(32, 32, 1): 1.0},
)
for s in samples:
assert s.multi_modal_data == []
@pytest.mark.benchmark
def test_random_mm_num_items_per_prompt(
hf_tokenizer: PreTrainedTokenizerBase) -> None:
ds = RandomMultiModalDataset(random_seed=0)
# Fixed number of images per prompt
# set num_mm_items_range_ratio to 0.0
# TODO: modify video values when video sampling is implemented
samples_fixed_items = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=5,
base_items_per_request=3,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt={"image": 3, "video": 0},
bucket_config={(32, 32, 1): 1.0},
)
# Must have 5 requests each with 3 mm items per prompt
assert len(samples_fixed_items) == 5
for s in samples_fixed_items:
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
assert len(mm_data) == 3
for it in mm_data:
assert it.get("type") == "image_url"
@pytest.mark.benchmark
def test_random_mm_bucket_config_not_mutated(
hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
ds = RandomMultiModalDataset(random_seed=0)
# This bucket config is not normalized to sum to 1
# and has more buckets than requested images
original = {(32, 32, 1): 0.2, (52, 64, 1): 6, (25, 64, 1): 3}
# Keep a snapshot to compare after sampling
snapshot = dict(original)
_ = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=4,
base_items_per_request=1,
num_mm_items_range_ratio=0.0,
limit_mm_per_prompt={"image": 1, "video": 0},
bucket_config=original,
)
# Ensure the original dict content is unchanged
assert original == snapshot
# Vary number of mm items per prompt
# set num_mm_items_range_ratio to 0.5
samples_varying_items = _collect_mm_samples(
ds,
hf_tokenizer,
num_requests=5,
base_items_per_request=2,
num_mm_items_range_ratio=0.5,
limit_mm_per_prompt={"image": 4, "video": 0},
bucket_config={(32, 32, 1): 1.0},
)
# Must have 5 requests each with less than 4 mm items per prompt
# but at least 1 mm item per prompt
assert len(samples_varying_items) == 5
for s in samples_varying_items:
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
assert len(mm_data) <= 4
assert len(mm_data) >= 1
for it in mm_data:
assert it.get("type") == "image_url"

View File

@ -9,12 +9,17 @@ import pytest
import torch
from packaging import version
from vllm import SamplingParams
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config)
from vllm.v1.attention.backends.flex_attention import (
FlexAttentionMetadataBuilder)
from ..models.utils import check_embeddings_close
from ..models.utils import check_embeddings_close, check_logprobs_close
TORCH_VERSION = version.parse(torch.__version__)
MINIMUM_TORCH_VERSION = version.parse("2.7.0")
DIRECT_BUILD_VERSION = version.parse("2.9.dev0")
def set_seed(seed):
@ -34,22 +39,18 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
"""Test that FlexAttention produces the same outputs as the default backend.
This test compares the outputs from the FlexAttention backend with
the default backend, ensuring they are identical when using the same seed.
the default backend, ensuring they are similar when using the same seed.
"""
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
seed = 42
max_tokens = 24
num_logprobs = 5
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
]
sampling_params = SamplingParams(temperature=0.0,
top_p=1.0,
seed=seed,
max_tokens=max_tokens)
# Run with flex attention
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
@ -61,7 +62,8 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True) as llm_flex:
output_flex = llm_flex.generate(prompts, sampling_params)
output_flex = llm_flex.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs)
# Run with default backend
with monkeypatch.context() as m:
@ -71,20 +73,17 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
runner="generate",
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True) as llm_default:
output_default = llm_default.generate(prompts, sampling_params)
enforce_eager=True,
gpu_memory_utilization=0.85) as llm_default:
output_default = llm_default.generate_greedy_logprobs(
prompts, max_tokens, num_logprobs)
# Compare outputs from both backends
for i, (flex_result,
default_result) in enumerate(zip(output_flex, output_default)):
prompt = prompts[i]
flex_text = flex_result[1][0]
default_text = default_result[1][0]
assert flex_text == default_text, (
f"FlexAttention output doesn't match default for: {prompt!r}\n"
f"FlexAttention: {flex_text!r}\n"
f"Default: {default_text!r}")
check_logprobs_close(
outputs_0_lst=output_flex,
outputs_1_lst=output_default,
name_0="flex",
name_1="default",
)
@pytest.mark.skipif(
@ -136,5 +135,70 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
)
@pytest.mark.skipif(
not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
reason="CUDA not available or PyTorch version < 2.7",
)
def test_block_mask_direct_vs_slow_path():
"""Test that direct path block mask is a superset of slow path.
The direct path may include extra blocks for performance (over-estimation),
but must include all blocks that the slow path determines are necessary.
"""
device = torch.device("cuda")
vllm_config = create_vllm_config(model_name="meta-llama/Meta-Llama-3-8B",
block_size=16,
max_model_len=1024)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
# Use a mixed batch that will create groups spanning multiple sequences
batch_spec = BatchSpec(seq_lens=[35, 64, 128, 256],
query_lens=[33, 5, 32, 64],
name="test_mixed_batch")
common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device)
builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config,
device)
metadata_direct = builder.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
builder.direct_build = False
metadata_slow = builder.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
assert metadata_direct.block_mask is not None
assert metadata_slow.block_mask is not None
# Extract block indices for comparison, B, H are the same
direct_indices = metadata_direct.block_mask.kv_indices[0, 0]
slow_indices = metadata_slow.block_mask.kv_indices[0, 0]
direct_num = metadata_direct.block_mask.kv_num_blocks[0, 0]
slow_num = metadata_slow.block_mask.kv_num_blocks[0, 0]
# main test: every block needed by slow path must be in direct path
num_groups = direct_num.shape[0]
all_contained = True
missing_details = []
for group_idx in range(num_groups):
direct_blocks = set(
direct_indices[group_idx, :direct_num[group_idx]].tolist())
slow_blocks = set(
slow_indices[group_idx, :slow_num[group_idx]].tolist())
missing_blocks = slow_blocks - direct_blocks
if missing_blocks:
all_contained = False
missing_details.append(
f"Group {group_idx}: missing {sorted(missing_blocks)}")
assert all_contained, (
"Direct path is missing blocks required by slow path:\n" +
"\n".join(missing_details))
if __name__ == "__main__":
pytest.main([__file__])

View File

@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo
from .mteb_utils import mteb_test_embed_models
# ST models with projector (Dense) layers
ST_PROJECTOR_MODELS = [
CLSPoolingEmbedModelInfo(
"TencentBAC/Conan-embedding-v1",
architecture="BertModel",
enable_test=True,
),
]
@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
model_info: EmbedModelInfo) -> None:
mteb_test_embed_models(hf_runner, vllm_runner, model_info)

View File

@ -17,13 +17,11 @@ from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
PromptReplacement, apply_text_matches,
apply_token_matches,
find_mm_placeholders,
find_text_matches, find_token_matches,
iter_token_matches,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby
from .utils import random_image
@ -75,12 +73,15 @@ from .utils import random_image
),
],
)
@pytest.mark.parametrize("start_idx", [0, 4, 8])
# yapf: enable
def test_iter_token_matches(token_ids, match_ids, expected):
result = list(iter_token_matches(token_ids, match_ids))
def test_iter_token_matches(token_ids, match_ids, expected, start_idx):
result = list(iter_token_matches(token_ids, match_ids,
start_idx=start_idx))
# Manually constructed results
assert [item._asdict() for item in result] == expected
assert [item._asdict() for item in result
] == [item for item in expected if item["start_idx"] >= start_idx]
# Invariants
match_lens = [end - start for start, end in result]
@ -241,21 +242,23 @@ def test_find_token_matches(
# Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(AnyTokenizer, object())
prompt_updates = [
update_type(key, target, []).bind(mock_tokenizer)
prompt_updates = {
key: update_type(key, target, []).resolve(0)
for key, target in target_by_key.items()
]
result = find_token_matches(prompt, prompt_updates)
}
result = {
key: list(update.iter_token_matches(prompt, mock_tokenizer))
for key, update in prompt_updates.items()
}
# Only displayed on error
print("result:", result)
# Manually constructed results
result_groups = dict(full_groupby(result, key=lambda x: x.modality))
assert {
key: [
dict(start_idx=item.start_idx, end_idx=item.end_idx)
for item in result_groups.get(key, [])
for item in result.get(key, [])
]
for key in expected_by_key
} == expected_by_key
@ -388,21 +391,23 @@ def test_find_text_matches(
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
prompt_updates = [
update_type(key, target, []).bind(mock_tokenizer)
prompt_updates = {
key: update_type(key, target, []).resolve(0)
for key, target in target_by_key.items()
]
result = find_text_matches(prompt, prompt_updates)
}
result = {
key: list(update.iter_text_matches(prompt, mock_tokenizer))
for key, update in prompt_updates.items()
}
# Only displayed on error
print("result:", result)
# Manually constructed results
result_groups = dict(full_groupby(result, key=lambda x: x.modality))
assert {
key: [
dict(start_idx=item.start_idx, end_idx=item.end_idx)
for item in result_groups.get(key, [])
for item in result.get(key, [])
]
for key in expected_by_key
} == expected_by_key
@ -552,39 +557,35 @@ def test_find_update_text(
update_type,
expected_by_mm_count,
) in expected_by_update_type_mm_count.items():
mm_prompt_updates = {
key:
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_text_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
}
for mm_count, expected in expected_by_mm_count.items():
result = apply_text_matches(
mm_prompt_updates = {
key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
for i in range(mm_count)]
for key, target in target_by_key.items()
}
new_prompt, result = apply_text_matches(
prompt,
mm_matches,
{key: mm_count
for key in repl_by_key},
mm_prompt_updates,
mock_tokenizer,
)
# Only displayed on error
print("update_type:", update_type)
print("mm_count:", mm_count)
print("mm_matches:", mm_matches)
print("mm_prompt_updates:", mm_prompt_updates)
print("new_prompt:", new_prompt)
print("result:", result)
# Manually constructed results
assert result == expected
assert new_prompt == expected
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[
# Tokenized test cases of `test_find_replace_text`
# Tokenized test cases of `test_find_update_text`
# using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
(
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
@ -726,32 +727,28 @@ def test_find_update_tokens(
update_type,
expected_by_mm_count,
) in expected_by_update_type_mm_count.items():
mm_prompt_updates = {
key:
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_token_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
}
for mm_count, expected in expected_by_mm_count.items():
result = apply_token_matches(
mm_prompt_updates = {
key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
for i in range(mm_count)]
for key, target in target_by_key.items()
}
new_prompt, result = apply_token_matches(
prompt,
mm_matches,
{key: mm_count
for key in repl_by_key},
mm_prompt_updates,
mock_tokenizer,
)
# Only displayed on error
print("update_type:", update_type)
print("mm_count:", mm_count)
print("mm_matches:", mm_matches)
print("mm_prompt_updates:", mm_prompt_updates)
print("new_prompt:", new_prompt)
print("result:", result)
# Manually constructed results
assert result == expected
assert new_prompt == expected
# yapf: disable
@ -878,17 +875,11 @@ def test_find_mm_placeholders(
mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_updates = {
key: [update_type(key, [], repl).bind(mock_tokenizer)]
key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
for key, repl in repl_by_key.items()
}
result = find_mm_placeholders(
mm_prompt_updates,
prompt,
# Effectively match all occurrences in the prompt
{key: 3
for key in repl_by_key},
)
result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer)
# Only displayed on error
print("result:", result)

View File

@ -10,14 +10,15 @@ from tests.v1.attention.utils import (BatchSpec, _Backend,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
set_kv_cache_layout)
from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
_Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN
_Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN,
"FLEX_ATTENTION_SLOW"
]
# Remove flashinfer from the list if it's not available
@ -97,7 +98,7 @@ def create_and_prepopulate_kv_cache(
common_attn_metadata: CommonAttentionMetadata,
randomize_blocks: bool = True) -> torch.Tensor:
"""Create and prepopulate a KV cache with context data.
Args:
k_contexts: List of key context tensors for each sequence
v_contexts: List of value context tensors for each sequence
@ -109,9 +110,9 @@ def create_and_prepopulate_kv_cache(
device: Device to create the cache on
num_blocks: Total number of blocks in the cache
block_table: Block table tensor to populate
randomize_blocks: Whether to randomly permute blocks
randomize_blocks: Whether to randomly permute blocks
or use sequential order
Returns:
Tuple of (kv_cache, updated_block_table)
"""
@ -206,10 +207,18 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
kv_cache: torch.Tensor) -> torch.Tensor:
"""Run attention computation using the specified backend's AttentionImpl."""
builder_cls, impl_cls = get_attention_backend(backend)
# Handle special case for FLEX_ATTENTION_SLOW
actual_backend = backend
use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0")
if backend == "FLEX_ATTENTION_SLOW":
actual_backend = _Backend.FLEX_ATTENTION
use_direct_block_mask = False
builder_cls, impl_cls = get_attention_backend(actual_backend)
# Mock flashinfer's get_per_layer_parameters if needed
if backend == _Backend.FLASHINFER_VLLM_V1:
if actual_backend == _Backend.FLASHINFER_VLLM_V1:
import unittest.mock
from vllm.v1.attention.backends.utils import PerLayerParameters
@ -239,6 +248,8 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
else:
# Build metadata
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
if actual_backend == _Backend.FLEX_ATTENTION:
builder.direct_build = use_direct_block_mask
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
@ -453,11 +464,6 @@ def test_backend_correctness(batch_spec_name: str, model: str):
rtol = 1e-2
atol = 5e-3
if backend_name == _Backend.FLEX_ATTENTION:
atol = 5e-1 # TODO: figure out why flex_attention has such large
# numerical differences for medium_decode, medium_prefill,
# mixed_medium
max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
max_rel_diff = torch.max(
torch.abs(backend_output - sdpa_output) /

View File

@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mamba attention backend selectors."""
from types import SimpleNamespace
import pytest
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.short_conv import ShortConv
from vllm.model_executor.models.minimax_text_01 import (
MiniMaxText01LinearAttention)
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)
@pytest.mark.parametrize(
"layer_class, init_kwargs, expected_backend, expected_mamba_type", [
(
MambaMixer,
dict(
hidden_size=128,
ssm_state_size=16,
conv_kernel_size=4,
intermediate_size=256,
time_step_rank=8,
use_conv_bias=True,
use_bias=False,
use_rms_norm=True,
),
Mamba1AttentionBackend,
"mamba1",
),
(
MambaMixer2,
dict(
hidden_size=128,
ssm_state_size=16,
conv_kernel_size=4,
intermediate_size=256,
use_conv_bias=True,
use_bias=False,
n_groups=1,
num_heads=8,
head_dim=32,
),
Mamba2AttentionBackend,
"mamba2",
),
(
MiniMaxText01LinearAttention,
dict(
hidden_size=128,
hidden_inner_size=256,
num_heads=8,
head_dim=32,
max_position=2048,
block_size=64,
num_hidden_layer=12,
layer_idx=0,
linear_layer_idx=0,
),
LinearAttentionBackend,
"linear_attention",
),
(
ShortConv,
dict(
config=SimpleNamespace(conv_L_cache=32, conv_bias=True),
dim=128,
layer_idx=0,
),
ShortConvAttentionBackend,
"short_conv",
),
])
def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs,
expected_backend, expected_mamba_type):
"""Test that Mamba-like layers return the correct attention backend."""
layer = layer_class(**init_kwargs)
backend_class = layer.get_attn_backend()
assert backend_class is expected_backend
assert layer.mamba_type == expected_mamba_type
@pytest.mark.parametrize("layer_class,expected_backend,expected_mamba_type", [
(MambaMixer, Mamba1AttentionBackend, "mamba1"),
(MambaMixer2, Mamba2AttentionBackend, "mamba2"),
(MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"),
(ShortConv, ShortConvAttentionBackend, "short_conv"),
])
def test_mamba_layers_have_unified_interface(layer_class, expected_backend,
expected_mamba_type):
"""Test that all Mamba layers have the unified get_attn_backend
interface."""
assert hasattr(layer_class, 'get_attn_backend'), (
f"{layer_class.__name__} should have get_attn_backend method")
assert hasattr(layer_class, 'mamba_type'), (
f"{layer_class.__name__} should have mamba_type property")

View File

@ -1,25 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mamba attention backend selectors."""
import pytest
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
@pytest.mark.parametrize(argnames=["mamba_type", "expected_backend"],
argvalues=[("mamba2", Mamba2AttentionBackend)])
def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend):
backend_class = get_mamba_attn_backend(mamba_type)
assert backend_class is expected_backend
def test_get_mamba_attn_backend_unsupported():
unsupported_types = ["mamba", ""]
for mamba_type in unsupported_types:
err_message = f"Mamba Attention type {mamba_type} is not supported yet."
with pytest.raises(NotImplementedError, match=err_message):
get_mamba_attn_backend(mamba_type)

View File

@ -0,0 +1,144 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
# ------------------ Mock Classes ------------------ #
class MockRequest:
def __init__(self, request_id, mm_hashes, token_counts):
self.request_id = request_id
self.mm_hashes = mm_hashes
self._token_counts = token_counts
def get_num_encoder_tokens(self, input_id: int) -> int:
return self._token_counts[input_id]
# ------------------ Unit Tests ------------------ #
def test_basic_allocate_and_reuse():
cache = EncoderCacheManager(cache_size=10)
req = MockRequest("r1", ["imgA"], [4])
assert not cache.check_and_update_cache(req, 0)
assert cache.try_allocate(req, 0, int(1e9))
cache.allocate(req, 0)
assert cache.check_and_update_cache(req, 0)
assert "r1" in cache.cached["imgA"]
assert cache.num_free_slots == 6
# Free twice to bring refcount to 0.
cache.free_encoder_input(req, 0)
cache.free_encoder_input(req, 0)
assert not cache.cached["imgA"]
assert "imgA" in cache.freeable
assert cache.num_freeable_slots == 10
assert cache.num_free_slots == 6
def test_freeing_decreases_refcount_and_moves_to_freeable():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("req2", ["img3"], [5])
assert manager.try_allocate(req, 0, int(1e9))
manager.allocate(req, 0)
assert len(manager.cached["img3"]) == 1
manager.free_encoder_input(req, 0)
assert not manager.cached["img3"]
assert "img3" in manager.freeable
assert manager.num_freeable_slots == 10
def test_free_request_frees_all_inputs():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("req3", ["a", "b"], [2, 3])
assert manager.try_allocate(req, 0, int(1e9))
manager.allocate(req, 0)
assert manager.try_allocate(req, 1, int(1e9))
manager.allocate(req, 1)
assert len(manager.cached["a"]) == 1
assert len(manager.cached["b"]) == 1
manager.free(req)
assert not manager.cached["a"]
assert not manager.cached["b"]
assert "a" in manager.freeable
assert "b" in manager.freeable
assert manager.num_freeable_slots == 10
def test_eviction_when_cache_is_full():
manager = EncoderCacheManager(cache_size=10)
req1 = MockRequest("req1", ["x"], [6])
req2 = MockRequest("req2", ["y"], [5])
assert manager.try_allocate(req1, 0, int(1e9))
manager.allocate(req1, 0)
manager.free_encoder_input(req1, 0)
assert manager.try_allocate(req2, 0, int(1e9))
manager.allocate(req2, 0)
# 'x' should have been evicted.
assert "x" not in manager.cached
assert "x" in manager.get_freed_mm_hashes()
def test_get_cached_input_ids():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("reqX", ["m", "n", "o"], [2, 4, 3])
assert manager.try_allocate(req, 0, int(1e9))
manager.allocate(req, 0)
assert manager.try_allocate(req, 2, int(1e9))
manager.allocate(req, 2)
cached_ids = manager.get_cached_input_ids(req)
assert cached_ids == {0, 2}
def test_has_cache_restores_from_freeable():
manager = EncoderCacheManager(cache_size=10)
req = MockRequest("reqY", ["imgZ"], [4])
assert manager.try_allocate(req, 0, int(1e9))
manager.allocate(req, 0)
manager.free_encoder_input(req, 0)
# Should restore from freeable.
assert manager.check_and_update_cache(req, 0)
assert len(manager.cached["imgZ"]) == 1
assert "imgZ" not in manager.freeable
assert manager.num_freeable_slots == 6
def test_get_freed_mm_hashes_clears_freed_list():
manager = EncoderCacheManager(cache_size=10)
req1 = MockRequest("reqA", ["a"], [5])
req2 = MockRequest("reqB", ["b"], [6])
assert manager.try_allocate(req1, 0, int(1e9))
manager.allocate(req1, 0)
manager.free_encoder_input(req1, 0)
# Should trigger eviction of 'a'.
assert manager.try_allocate(req2, 0, int(1e9))
manager.allocate(req2, 0)
freed = manager.get_freed_mm_hashes()
assert "a" in freed
assert manager.get_freed_mm_hashes() == []

View File

@ -338,7 +338,7 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
@ -391,7 +391,7 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -443,7 +443,7 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -490,7 +490,7 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)

View File

@ -143,7 +143,11 @@ def create_requests(
mm_position = mm_positions[i]
mm_item = MultiModalKwargsItem.dummy("dummy_m")
mm_kwargs = [mm_item] * len(mm_position)
mm_hashes = ["hash"] * len(mm_position)
# Dummy hash for each mm item should be unique
# since encoder cache tracks entries by hash
mm_hashes = [
"hash" + str(i) + "_" + str(j) for j in range(len(mm_position))
]
else:
mm_position = None
mm_kwargs = None

View File

@ -41,8 +41,11 @@ EAGLE_SPEC_CONFIG = {
PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto",
None),
("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None),
("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None),
("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None),
("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto",
@ -148,7 +151,8 @@ def test_structured_output(
generated_text = output.outputs[0].text
assert generated_text is not None
assert "\n" not in generated_text
if guided_decoding_backend != 'lm-format-enforcer':
assert "\n" not in generated_text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@ -225,7 +229,7 @@ def test_structured_output(
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
if guided_decoding_backend != "outlines":
if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
#
# Test 4: Generate SQL statement using EBNF grammar
#
@ -439,7 +443,7 @@ def test_structured_output(
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema)
if guided_decoding_backend != "outlines":
if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]:
#
# Test 11: Generate structured output using structural_tag format
#

View File

@ -85,7 +85,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -164,7 +164,7 @@ def test_update_states_request_finished(model_runner):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids={req_id},
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -194,7 +194,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -221,7 +221,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -252,7 +252,7 @@ def test_update_states_no_changes(model_runner):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -287,7 +287,7 @@ def test_update_states_request_unscheduled(model_runner):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)

View File

@ -205,6 +205,7 @@ def _construct_cached_request_state(req_id_suffix: int):
pooling_params=None,
mm_kwargs=[],
mm_positions=[],
mm_hashes=[],
block_ids=([], ),
generator=None,
num_computed_tokens=len(output_token_ids),

View File

@ -141,7 +141,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -207,7 +207,7 @@ def test_update_states_request_finished(model_runner, dist_init):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids={req_id},
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -239,7 +239,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -266,7 +266,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -347,7 +347,7 @@ def test_update_states_no_changes(model_runner, dist_init):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -384,7 +384,7 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None,
)

View File

@ -18,6 +18,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@ -54,7 +55,7 @@ def check_xformers_availability():
return USE_XFORMERS_OPS
class Attention(nn.Module):
class Attention(nn.Module, AttentionLayerBase):
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors

View File

@ -11,18 +11,21 @@ generation. Supported dataset types include:
- HuggingFace
- VisionArena
"""
import ast
import base64
import io
import json
import logging
import math
import random
from abc import ABC, abstractmethod
from collections.abc import Mapping
from collections.abc import Iterator, Mapping
from contextlib import suppress
from copy import deepcopy
from dataclasses import dataclass
from functools import cache
from io import BytesIO
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union, cast
import numpy as np
from PIL import Image
@ -114,7 +117,9 @@ class BenchmarkDataset(ABC):
def apply_multimodal_chat_transformation(
self,
prompt: str,
mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
mm_content: Optional[
Union[MultiModalDataDict, dict, list[dict]]
] = None) -> list[dict]:
"""
Transform a prompt and optional multimodal content into a chat format.
This method is used for chat models that expect a specific conversation
@ -122,7 +127,15 @@ class BenchmarkDataset(ABC):
"""
content = [{"text": prompt, "type": "text"}]
if mm_content is not None:
content.append(mm_content)
if isinstance(mm_content, list):
content.extend(cast(list[dict[str, Any]], mm_content))
elif isinstance(mm_content, dict):
content.append(mm_content)
else:
raise TypeError(
"Could not process multimodal content of type: " +
f"{type(mm_content)}"
)
return [{"role": "user", "content": content}]
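As a standalone sketch (hypothetical names, plain dicts standing in for vLLM's `MultiModalDataDict`), the branching above behaves like:

```python
def to_chat_messages(prompt, mm_content=None):
    # Sketch of the transformation above: the text item goes first, then
    # one or many multimodal content dicts are appended to the same user turn.
    content = [{"text": prompt, "type": "text"}]
    if mm_content is not None:
        if isinstance(mm_content, list):
            content.extend(mm_content)
        elif isinstance(mm_content, dict):
            content.append(mm_content)
        else:
            raise TypeError(
                f"Could not process multimodal content of type: {type(mm_content)}")
    return [{"role": "user", "content": content}]

msgs = to_chat_messages("describe", [{"type": "image_url"}, {"type": "image_url"}])
print(len(msgs[0]["content"]))  # text item plus two images -> 3
```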
def load_data(self) -> None:
@ -362,90 +375,536 @@ def process_video(video: Any) -> Mapping[str, Any]:
class RandomDataset(BenchmarkDataset):
"""
Synthetic text-only dataset for serving/throughput benchmarks.
Strategy:
- Sample input/output token lengths per request from integer-uniform ranges
around configured means (controlled by range_ratio).
- Prepend a fixed random prefix of length prefix_len.
- Generate the remaining tokens as a reproducible sequence:
(offset + index + arange(input_len)) % vocab_size.
- Decode then re-encode/truncate to ensure prompt token counts match.
- Uses numpy.default_rng seeded with random_seed for reproducible sampling.
"""
# Default values copied from benchmark_serving.py for the random dataset.
DEFAULT_PREFIX_LEN = 0
DEFAULT_RANGE_RATIO = 0.0
DEFAULT_INPUT_LEN = 1024
DEFAULT_OUTPUT_LEN = 128
def __init__(
self,
**kwargs,
) -> None:
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
random.seed(self.random_seed)
np.random.seed(self.random_seed)
# Use numpy's default_rng for deterministic sampling
# Do not use random.seed() or np.random.seed() elsewhere in this class.
# This ensures that the RNG is isolated from global RNG state.
self._rng = np.random.default_rng(self.random_seed)
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
request_id_prefix: str = "",
prefix_len: int = DEFAULT_PREFIX_LEN,
range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
request_id_prefix: str = "",
**kwargs,
) -> list[SampleRequest]:
# Enforce range_ratio < 1
assert range_ratio < 1.0, (
"random_range_ratio must be < 1.0 to ensure a valid sampling range"
input_lens, output_lens, offsets = self.get_sampling_params(
num_requests, range_ratio, input_len, output_len, tokenizer
)
# Generate prefix once
prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
vocab_size = tokenizer.vocab_size
num_special_tokens = tokenizer.num_special_tokens_to_add()
real_input_len = input_len - num_special_tokens
prefix_token_ids = (np.random.randint(
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
# New sampling logic: [X * (1 - b), X * (1 + b)]
input_low = int(real_input_len * (1 - range_ratio))
input_high = int(real_input_len * (1 + range_ratio))
output_low = int(output_len * (1 - range_ratio))
output_high = int(output_len * (1 + range_ratio))
# Add logging for debugging
logger.info(
"Sampling input_len from [%s, %s] and output_len from [%s, %s]",
input_low, input_high, output_low, output_high)
input_lens = np.random.randint(input_low,
input_high + 1,
size=num_requests)
output_lens = np.random.randint(output_low,
output_high + 1,
size=num_requests)
offsets = np.random.randint(0, vocab_size, size=num_requests)
requests = []
for i in range(num_requests):
inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
vocab_size).tolist()
token_sequence = prefix_token_ids + inner_seq
prompt = tokenizer.decode(token_sequence)
# After decoding the prompt we have to encode and decode it again.
# This is done because in some cases N consecutive tokens
# give a string tokenized into != N number of tokens.
# For example for GPT2Tokenizer:
# [6880, 6881] -> ['Ġcalls', 'here'] ->
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
# To avoid uncontrolled change of the prompt length,
# the encoded sequence is truncated before being decoded again.
total_input_len = prefix_len + int(input_lens[i])
re_encoded_sequence = tokenizer.encode(
prompt, add_special_tokens=False)[:total_input_len]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = len(re_encoded_sequence)
prompt, total_input_len = self.generate_token_sequence(
tokenizer=tokenizer,
prefix_token_ids=prefix_token_ids,
prefix_len=prefix_len,
vocab_size=vocab_size,
input_len=int(input_lens[i]),
offset=int(offsets[i]),
index=i,
)
requests.append(
SampleRequest(
prompt=prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
request_id=request_id_prefix + str(i),
))
)
)
return requests
def get_prefix(
self, tokenizer: PreTrainedTokenizerBase, prefix_len: int
) -> list[int]:
"""
Get the prefix for the dataset.
"""
return (
self._rng.integers(
0, tokenizer.vocab_size, size=prefix_len).tolist()
if prefix_len > 0
else []
)
def get_sampling_params(
self,
num_requests: int,
range_ratio: float,
input_len: int,
output_len: int,
tokenizer: PreTrainedTokenizerBase,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Get the sampling parameters for the dataset.
"""
# Enforce range_ratio < 1
if not (0.0 <= range_ratio < 1.0):
raise ValueError("range_ratio must be in [0, 1).")
num_special_tokens = int(tokenizer.num_special_tokens_to_add())
real_input_len = max(0, int(input_len) - num_special_tokens)
# Bounds use floor for low and ceil for high
input_low = math.floor(real_input_len * (1 - range_ratio))
input_high = math.ceil(real_input_len * (1 + range_ratio))
output_low = math.floor(output_len * (1 - range_ratio))
output_high = math.ceil(output_len * (1 + range_ratio))
# Ensure the lower bound for output length is at least 1 to
# prevent sampling 0 tokens.
output_low = max(output_low, 1)
if input_low > input_high:
raise ValueError(
"Invalid input sampling interval: "
f"low={input_low} > high={input_high}"
)
if output_low > output_high:
raise ValueError(
"Invalid output sampling interval: "
f"low={output_low} > high={output_high}"
)
logger.info(
"Sampling input_len from [%s, %s] and output_len from [%s, %s]",
input_low,
input_high,
output_low,
output_high,
)
input_lens = self._rng.integers(input_low, input_high + 1,
size=num_requests)
output_lens = self._rng.integers(output_low, output_high + 1,
size=num_requests)
offsets = self._rng.integers(0, tokenizer.vocab_size,
size=num_requests)
return input_lens, output_lens, offsets
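The interval arithmetic in `get_sampling_params` can be exercised in isolation; this is a minimal sketch (hypothetical helper, no tokenizer involved):

```python
import math

def sampling_bounds(real_input_len, output_len, range_ratio):
    # Mirrors the bounds logic above: floor for the low end, ceil for the
    # high end, and the output low bound is clamped to at least 1 token.
    if not (0.0 <= range_ratio < 1.0):
        raise ValueError("range_ratio must be in [0, 1).")
    input_low = math.floor(real_input_len * (1 - range_ratio))
    input_high = math.ceil(real_input_len * (1 + range_ratio))
    output_low = max(1, math.floor(output_len * (1 - range_ratio)))
    output_high = math.ceil(output_len * (1 + range_ratio))
    return (input_low, input_high), (output_low, output_high)

print(sampling_bounds(1022, 128, 0.25))  # ((766, 1278), (96, 160))
```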
def generate_token_sequence(
self,
*,
tokenizer: PreTrainedTokenizerBase,
prefix_token_ids: list[int],
prefix_len: int,
vocab_size: int,
input_len: int,
offset: int,
index: int,
) -> tuple[str, int]:
"""
Returns (prompt, total_input_len).
NOTE: After decoding the prompt we have to encode and decode it again.
This is done because in some cases N consecutive tokens
give a string tokenized into != N number of tokens.
For example for GPT2Tokenizer:
[6880, 6881] -> ['Ġcalls', 'here'] ->
[1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
To avoid uncontrolled change of the prompt length,
the encoded sequence is truncated before being decoded again.
"""
# Build the inner sequence by sampling sequentially from the vocab
inner_seq = ((offset + index + np.arange(input_len))
% vocab_size).tolist()
token_sequence = prefix_token_ids + inner_seq
# Decode, then re-encode and truncate to preserve token count invariants
prompt = tokenizer.decode(token_sequence)
total_input_len = prefix_len + int(input_len)
re_encoded_sequence = tokenizer.encode(
prompt, add_special_tokens=False)[:total_input_len]
prompt = tokenizer.decode(re_encoded_sequence)
total_input_len = len(re_encoded_sequence)
return prompt, total_input_len
# -----------------------------------------------------------------------------
# MultiModalDataset Implementation
# -----------------------------------------------------------------------------
class RandomMultiModalDataset(RandomDataset):
"""
Synthetic multimodal dataset (text + images) that extends RandomDataset.
Status:
- Images: supported via synthetic RGB data.
- Video: not yet supported (TODO: implement video generation method).
- Audio: not yet supported.
Sampling overview:
1) Number of items per request is sampled uniformly from the integer range
[floor(n·(1-r)), ceil(n·(1+r))], where n is the base count and r is
`num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0.
The maximum is further clamped to the sum of per-modality limits.
2) Each item's modality and shape is sampled from `bucket_config`, a dict
mapping (height, width, num_frames) -> probability. We treat
`num_frames`=1 as image and `num_frames` > 1 as video.
Entries with zero probability are removed and the rest are renormalized
to sum to 1.
3) Per-modality hard caps are enforced via `limit_mm_per_prompt`.
When a modality reaches its cap, all of its buckets are excluded and the
remaining probabilities are renormalized.
Example bucket configuration:
{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1}
- Two image buckets (`num_frames`=1) and one video bucket
(`num_frames`=16).
Note: Only image sampling is supported for now.
"""
IS_MULTIMODAL = True
# NOTE: video sampling is WIP. Setting it to 0.
DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 0}
DEFAULT_BASE_ITEMS_PER_REQUEST = 1
DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0
DEFAULT_MM_ITEM_BUCKET_CONFIG = {
(256, 256, 1): 0.5,
(720, 1280, 1): 0.5,
(720, 1280, 16): 0.0,
}
DEFAULT_ENABLE_MULTIMODAL_CHAT = False
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
def generate_synthetic_image(self, width: int, height: int) -> Image.Image:
"""Generate synthetic PIL image with random RGB values.
NOTE: iid pixel sampling results in worst-case compression
(good for stressing I/O), but very unlike real photos.
We could consider a low-freq mode (e.g., noise blur)
to emulate network realism instead of max stress.
"""
random_pixels = self._rng.integers(
0,
256,
(height, width, 3),
dtype=np.uint8,
)
return Image.fromarray(random_pixels)
def generate_synthetic_video(self, width: int,
height: int,
num_frames: int) -> Any:
"""Generate synthetic video with random values.
TODO: Finish this method.
"""
raise NotImplementedError("Video sampling is WIP.")
def map_config_to_modality(self, config: tuple[int, int, int]) -> str:
"""Map the configuration to the modality."""
if config[-1] == 1:
return "image"
elif config[-1] > 1:
return "video"
else:
raise ValueError(f"Invalid multimodal item configuration: {config}")
def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int],
float]) -> dict[tuple[int, int, int], float]:
"""
Remove zero probability entries
and normalize the bucket config to sum to 1.
"""
# Raise error if value is negative
if any(v < 0 for v in bucket_config.values()):
raise ValueError("Bucket config values must be non-negative.")
# Remove zero probability entries
bucket_config = {k: v for k, v in bucket_config.items() if v > 0}
# if bucket config is empty, raise error
if not bucket_config:
raise ValueError("Got invalid bucket config. "
"Bucket config values must be non-zero.")
# Normalize the remaining bucket config to sum to 1
total = sum(bucket_config.values())
return {k: v / total for k, v in bucket_config.items()}
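The normalization step above can be run standalone; this sketch uses plain dicts keyed by the (height, width, num_frames) tuples:

```python
def normalize_bucket_config(bucket_config):
    # Sketch of the normalization above: drop zero-probability buckets,
    # reject negative or all-zero configs, renormalize the rest to sum to 1.
    if any(v < 0 for v in bucket_config.values()):
        raise ValueError("Bucket config values must be non-negative.")
    kept = {k: v for k, v in bucket_config.items() if v > 0}
    if not kept:
        raise ValueError("Bucket config values must be non-zero.")
    total = sum(kept.values())
    return {k: v / total for k, v in kept.items()}

cfg = normalize_bucket_config(
    {(256, 256, 1): 0.5, (720, 1280, 1): 0.3, (720, 1280, 16): 0.0})
print(cfg)  # the zero-probability video bucket is dropped
```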
def generate_mm_item(self,
mm_item_config: tuple[int, int, int],
) -> Mapping[str, Any]:
"""
Create synthetic images and videos and
apply process_image/process_video respectively.
This follows the OpenAI API chat completions
https://github.com/openai/openai-python
"""
if self.map_config_to_modality(mm_item_config) == "image":
return process_image(self.generate_synthetic_image(
mm_item_config[1],
mm_item_config[0]))
elif self.map_config_to_modality(mm_item_config) == "video":
return process_video(self.generate_synthetic_video(
mm_item_config[1],
mm_item_config[0],
mm_item_config[2]))
else:
raise ValueError(f"Invalid multimodal item configuration: "
f"{mm_item_config}")
def get_mm_item_sampling_params(
self,
base_items_per_request: int,
num_mm_items_range_ratio: float,
limit_mm_per_prompt: dict[str, int],
bucket_config: dict[tuple[int, int, int], float],
) -> tuple[int, int, dict[str, int], dict[tuple[int, int, int], float]]:
"""
Get the sampling parameters for the multimodal items.
"""
# Enforce num_mm_items_range_ratio <= 1
if not (0.0 <= num_mm_items_range_ratio <= 1.0):
raise ValueError("num_mm_items_range_ratio must be in [0, 1].")
# Ensure modalities to sample are in limit_mm_per_prompt
for k, v in bucket_config.items():
# get modality from bucket config
modality = self.map_config_to_modality(k)
if modality not in limit_mm_per_prompt:
raise ValueError(f"Modality {modality} is not in "
f"limit_mm_per_prompt: "
f"{limit_mm_per_prompt.keys()}")
# Remove zero probability entries
# and normalize bucket config to sum to 1
bucket_config = self.normalize_bucket_config(bucket_config)
logger.info(
"Normalized bucket config: %s", bucket_config,
)
# Only consider limit per prompt for modalities in bucket config
allowed_modalities = {self.map_config_to_modality(cfg)
for cfg in bucket_config}
limit_mm_per_prompt = {
k: v for k, v in limit_mm_per_prompt.items()
if k in allowed_modalities}
if not limit_mm_per_prompt:
raise ValueError("No valid limits for modalities present in "
"bucket_config.")
logger.info(
"Updated mm-limit-per-prompt: %s", limit_mm_per_prompt,
)
# Get max and min num mm items and ensure
# it is at most the sum of limit_mm_per_prompt for all modalities
max_num_mm_items = min(
sum(limit_mm_per_prompt.values()),
math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio))
)
# Ensure min num mm items is at least 0
min_num_mm_items = max(
0,
math.floor(base_items_per_request * (1 - num_mm_items_range_ratio))
)
# Raise error if min num mm items is greater than max num mm items
if min_num_mm_items > max_num_mm_items:
raise ValueError(f"Min num mm items is greater than max mm items: "
f"{min_num_mm_items} > {max_num_mm_items}")
logger.info(
"Sampling number of multimodal items from [%s, %s]",
min_num_mm_items, max_num_mm_items,
)
return (
min_num_mm_items,
max_num_mm_items,
limit_mm_per_prompt,
bucket_config,
)
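The clamping of the item-count range above reduces to a few lines; a sketch with hypothetical standalone names:

```python
import math

def mm_item_count_bounds(base_items, range_ratio, per_modality_limits):
    # Sketch of the clamping above: the max item count cannot exceed the
    # sum of per-modality limits, and the min cannot go below zero.
    max_items = min(sum(per_modality_limits.values()),
                    math.ceil(base_items * (1 + range_ratio)))
    min_items = max(0, math.floor(base_items * (1 - range_ratio)))
    if min_items > max_items:
        raise ValueError(f"Min num mm items is greater than max mm items: "
                         f"{min_items} > {max_items}")
    return min_items, max_items

print(mm_item_count_bounds(4, 0.5, {"image": 3, "video": 0}))  # (2, 3)
```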
def get_mm_item_iterator(
self,
min_num_mm_items: int,
max_num_mm_items: int,
bucket_config: dict[tuple[int, int, int], float],
limit_mm_per_prompt: dict[str, int],
) -> Iterator[tuple[int, int, int]]:
"""
Iterator over the multimodal items for each request
whose size is between min_num_mm_items and max_num_mm_items.
Loop over the bucket config and sample a multimodal item.
Loop until the number of multimodal items sampled is equal to
request_num_mm_items or limit of multimodal items per prompt
for all modalities is reached.
Note:
- This function operates on a per-request shallow copy of
`bucket_config` (tuple->float). The original dict passed to
`sample` is not mutated. If this ever changes, an existing
test will fail.
"""
# Get the number of multimodal items to sample
request_num_mm_items = int(
self._rng.integers(min_num_mm_items, max_num_mm_items + 1)
)
# If request_num_mm_items is 0, yield an empty iterator
if request_num_mm_items == 0:
return
# Initialize modality counters
modality_counter = {self.map_config_to_modality(k): 0
for k in bucket_config}
# Copy the bucket config to avoid modifying the original
bucket_config_copy = bucket_config.copy()
# Loop over the number of multimodal items to sample
while sum(modality_counter.values()) < request_num_mm_items:
# Sample a multimodal item config
mm_item_config = self._rng.choice(list(bucket_config_copy.keys()),
p=list(bucket_config_copy.values()))
modality = self.map_config_to_modality(mm_item_config)
# Check that modality count is less than limit per prompt
if modality_counter[modality] < limit_mm_per_prompt[modality]:
modality_counter[modality] += 1
yield mm_item_config
else:
# The modality has reached its per-prompt limit:
# zero out all buckets of this modality
for k, v in bucket_config_copy.items():
if self.map_config_to_modality(k) == modality:
bucket_config_copy[k] = 0
# If all configs are 0, break the loop
# This should not happen as request_num_mm_items is at most
# the sum of limit_mm_per_prompt for all modalities
if all(v == 0 for v in bucket_config_copy.values()):
logger.warning("Exhausted all multimodal items "
"of modality %s",
modality)
break
# Renormalize the bucket config
bucket_config_copy = self.normalize_bucket_config(
bucket_config_copy)
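The renormalization step at the end of the loop can be sketched in isolation (a standalone version of what `normalize_bucket_config` is assumed to do; the class method defined elsewhere may differ in details):

```python
def normalize_bucket_config(
    bucket_config: dict[tuple[int, int, int], float],
) -> dict[tuple[int, int, int], float]:
    # Rescale the remaining bucket probabilities so they sum to 1,
    # e.g. after all buckets of an exhausted modality were set to 0.
    total = sum(bucket_config.values())
    if total <= 0:
        raise ValueError("All bucket probabilities are zero.")
    return {k: v / total for k, v in bucket_config.items()}


cfg = {(256, 256, 1): 0.5, (720, 1280, 1): 0.3, (720, 1280, 16): 0.0}
print(normalize_bucket_config(cfg))
```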
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
request_id_prefix: str = "",
prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN,
range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN,
limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT,
base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST,
num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
bucket_config: dict[tuple[int, int, int], float] =
DEFAULT_MM_ITEM_BUCKET_CONFIG,
enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT,
**kwargs,
) -> list[SampleRequest]:
# NOTE: Video sampling is WIP. Raise error if video is in bucket config
# and probability is non-zero.
if any(self.map_config_to_modality(cfg) == "video" and p > 0
for cfg, p in bucket_config.items()):
raise NotImplementedError("Video sampling not implemented; "
"set its probability to 0.")
# Get the sampling parameters for the dataset
input_lens, output_lens, offsets = self.get_sampling_params(
num_requests, range_ratio, input_len, output_len, tokenizer
)
(
min_num_mm_items,
max_num_mm_items,
limit_mm_per_prompt,
bucket_config,
) = self.get_mm_item_sampling_params(
base_items_per_request,
num_mm_items_range_ratio,
limit_mm_per_prompt,
bucket_config,
)
# Generate prefix once
prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
vocab_size = tokenizer.vocab_size
# Add synthetic multimodal items to each request
mm_requests = []
for i in range(num_requests):
prompt, total_input_len = self.generate_token_sequence(
tokenizer=tokenizer,
prefix_token_ids=prefix_token_ids,
prefix_len=prefix_len,
vocab_size=vocab_size,
input_len=int(input_lens[i]),
offset=int(offsets[i]),
index=i,
)
# Get multimodal item iterator for a given request
mm_item_iterator = self.get_mm_item_iterator(
min_num_mm_items,
max_num_mm_items,
bucket_config,
limit_mm_per_prompt,
)
mm_content = cast(list[dict[str, Any]], [
self.generate_mm_item(mm_item_config)
for mm_item_config in mm_item_iterator
])
if enable_multimodal_chat:
# NOTE: For now this option is only provided for completeness
# given that the serve.py benchmark currently does not use it.
mm_chat_prompt: Any = prompt
mm_chat_prompt = self.apply_multimodal_chat_transformation(
prompt, mm_content)
sample_request = SampleRequest(
prompt=mm_chat_prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
multi_modal_data=None,
request_id=request_id_prefix + str(i),
)
else:
sample_request = SampleRequest(
prompt=prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
multi_modal_data=mm_content,
request_id=request_id_prefix + str(i),
)
mm_requests.append(sample_request)
return mm_requests
# -----------------------------------------------------------------------------
# ShareGPT Dataset Implementation
@ -545,8 +1004,8 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
type=str,
default="random",
choices=[
"sharegpt", "burstgpt", "sonnet", "random", "hf", "custom",
"prefix_repetition"
"sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
"custom", "prefix_repetition"
],
help="Name of the dataset to benchmark on.",
)
@ -647,6 +1106,98 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"input_len * (1 + range_ratio)]."),
)
# random multimodal dataset options
random_mm_group = parser.add_argument_group(
"random multimodal dataset options (extends the random dataset)")
random_mm_group.add_argument(
"--random-mm-base-items-per-request",
type=int,
default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST,
help=(
"Base number of multimodal items per request for random-mm. "
"Actual per-request count is sampled around this base using "
"--random-mm-num-mm-items-range-ratio."
),
)
random_mm_group.add_argument(
"--random-mm-num-mm-items-range-ratio",
type=float,
default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
help=(
"Range ratio r in [0, 1] for sampling items per request. "
"We sample uniformly from the closed integer range "
"[floor(n*(1-r)), ceil(n*(1+r))] "
"where n is the base items per request. "
"r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped "
"to the sum of per-modality limits from "
"--random-mm-limit-mm-per-prompt. "
"An error is raised if the computed min exceeds the max."
),
)
random_mm_group.add_argument(
"--random-mm-limit-mm-per-prompt",
type=json.loads,
default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT,
help=(
"Per-modality hard caps for items attached per request, e.g. "
"'{\"image\": 3, \"video\": 0}'. The sampled per-request item "
"count is clamped to the sum of these limits. When a modality "
"reaches its cap, its buckets are excluded and probabilities are "
"renormalized. "
"Note: only image sampling is supported for now."
),
)
def _parse_mm_bucket_config(v: object) -> dict[tuple[int, int, int], float]:
# If already a dict (e.g., programmatic call), normalize keys
def normalize(d: dict) -> dict[tuple[int, int, int], float]:
out: dict[tuple[int, int, int], float] = {}
for k, val in d.items():
key = k
if isinstance(key, str):
with suppress(Exception):
key = ast.literal_eval(key)
if not (isinstance(key, tuple) and len(key) == 3
and all(isinstance(x, int) for x in key)):
raise ValueError(
f"Invalid bucket key {k!r}. Expected tuple (H, W, T)."
)
out[(int(key[0]), int(key[1]), int(key[2]))] = float(val)
return out
if isinstance(v, dict):
return normalize(v)
if isinstance(v, str):
# Python literal (supports tuple keys)
parsed = ast.literal_eval(v)
if not isinstance(parsed, dict):
raise ValueError("Bucket config must parse to a dict.")
return normalize(parsed)
raise ValueError("Unsupported value for --random-mm-bucket-config.")
random_mm_group.add_argument(
"--random-mm-bucket-config",
type=_parse_mm_bucket_config,
default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG,
help=(
"The bucket config is a dictionary mapping a multimodal item "
"sampling configuration to a probability. "
"Currently allows for 2 modalities: images and videos. "
"A bucket key is a tuple of (height, width, num_frames). "
"The value is the probability of sampling that specific item. "
"Example: "
"--random-mm-bucket-config "
"{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10} "
"First item: images with resolution 256x256 w.p. 0.5. "
"Second item: images with resolution 720x1280 w.p. 0.4. "
"Third item: videos with resolution 720x1280 and 16 frames w.p. 0.1. "
"Note: if the probabilities do not sum to 1, they are normalized. "
"Note: only image sampling is supported for now."
),
)
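A quick round-trip of the CLI value illustrates the accepted format: a Python-literal dict with (height, width, num_frames) tuple keys, as handled by `_parse_mm_bucket_config` above:

```python
import ast

# The example string from the help text above.
raw = "{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10}"
parsed = ast.literal_eval(raw)

# Every key is an (H, W, T) tuple of ints; values are probabilities.
assert all(
    isinstance(k, tuple) and len(k) == 3
    and all(isinstance(x, int) for x in k)
    for k in parsed
)
assert abs(sum(parsed.values()) - 1.0) < 1e-9
print(parsed)
```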
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset",
type=str,
@ -821,6 +1372,22 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
range_ratio=args.random_range_ratio,
request_id_prefix=args.request_id_prefix,
),
"random-mm":
lambda: RandomMultiModalDataset(
random_seed=args.seed, dataset_path=args.dataset_path
).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
prefix_len=args.random_prefix_len,
range_ratio=args.random_range_ratio,
input_len=args.random_input_len,
output_len=args.random_output_len,
base_items_per_request=args.random_mm_base_items_per_request,
limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt,
num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio,
bucket_config=args.random_mm_bucket_config,
request_id_prefix=args.request_id_prefix,
),
"prefix_repetition":
lambda: PrefixRepetitionRandomDataset(
random_seed=args.seed, dataset_path=args.dataset_path
@ -836,6 +1403,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
}
try:
# Enforce endpoint compatibility for multimodal datasets.
if args.dataset_name == "random-mm" and args.endpoint_type not in [
"openai-chat"]:
raise ValueError(
"Multi-modal content (images) is only supported on "
"'openai-chat' backend."
)
input_requests = dataset_mapping[args.dataset_name]()
except KeyError as err:
raise ValueError(f"Unknown dataset: {args.dataset_name}") from err

View File

@ -3057,7 +3057,8 @@ def get_served_model_name(model: str,
return served_model_name
GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines"]
GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines",
"lm-format-enforcer"]
@config

View File

@ -663,9 +663,9 @@ class OpenAIServingChat(OpenAIServing):
harmony_parser = harmony_parsers[i]
for token_id in output.token_ids:
harmony_parser.process(token_id)
# FIXME(woosuk): Support function calling
is_final = harmony_parser.current_channel == "final"
if not (request.include_reasoning or is_final):
is_reasoning = \
harmony_parser.current_channel == "analysis"
if not request.include_reasoning and is_reasoning:
# Skip the reasoning content.
continue
delta_text = harmony_parser.last_content_delta or ""
@ -695,11 +695,11 @@ class OpenAIServingChat(OpenAIServing):
current_token_ids = as_list(output.token_ids)
if self.use_harmony:
if is_final:
delta_message = DeltaMessage(content=delta_text)
else:
if is_reasoning:
delta_message = DeltaMessage(
reasoning_content=delta_text)
else:
delta_message = DeltaMessage(content=delta_text)
# handle streaming deltas for tools with named tool_choice
elif tool_choice_function_name:
if (self.reasoning_parser and not reasoning_end_arr[i]

View File

@ -64,8 +64,8 @@ class DeepSeekV31ToolParser(ToolParser):
if (self.tool_calls_start_token_id is None
or self.tool_calls_end_token_id is None):
raise RuntimeError(
"DeepSeek-V3 Tool parser could not locate tool call start/end "
"tokens in the tokenizer!")
"DeepSeek-V3.1 Tool parser could not locate tool call "
"start/end tokens in the tokenizer!")
def extract_tool_calls(
self,

View File

@ -422,12 +422,23 @@ _ACTIVATION_REGISTRY = LazyDict({
lambda: nn.SiLU(),
"quick_gelu":
lambda: QuickGELU(),
"tanh":
lambda: nn.Tanh(),
"sigmoid":
lambda: nn.Sigmoid(),
})
def get_act_fn(act_fn_name: str) -> nn.Module:
"""Get an activation function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name.startswith("torch.nn.modules."):
activation_name = act_fn_name.split(".")[-1]
if activation_name == "identity":
return nn.Identity()
act_fn_name = activation_name
if act_fn_name not in _ACTIVATION_REGISTRY:
raise ValueError(
f"Activation function {act_fn_name!r} is not supported.")
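The prefix handling in `get_act_fn` can be isolated as a tiny pure function (a hypothetical helper that mirrors only the name normalization, not the registry lookup):

```python
def normalize_act_name(act_fn_name: str) -> str:
    # Mirrors the prefix handling above: lowercase the name and strip
    # a "torch.nn.modules." path down to the bare class name.
    act_fn_name = act_fn_name.lower()
    if act_fn_name.startswith("torch.nn.modules."):
        act_fn_name = act_fn_name.split(".")[-1]
    return act_fn_name


print(normalize_act_name("torch.nn.modules.activation.Tanh"))
```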

View File

@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for attention-like layers."""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class AttentionLayerBase(ABC):
"""
Base class for attention-like layers (Attention, Mamba, etc.)
that support the v1 engine.
This provides a common interface for getting attention backends
from different layer types.
"""
@abstractmethod
def get_attn_backend(self) -> type["AttentionBackend"]:
"""Get the attention backend class for this layer."""
pass

View File

@ -1378,7 +1378,7 @@ class RowParallelLinear(LinearBase):
return output, output_bias
def extra_repr(self) -> str:
s = f"input_features={self.input_size_per_partition}"
s = f"in_features={self.input_size_per_partition}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
s += f", tp_size={self.tp_size}"

View File

@ -1,12 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from abc import abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING
import torch
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
class MambaBase(ABC):
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class MambaBase(AttentionLayerBase):
"""
Base class for Mamba-like layers which support the v1 engine.
Inherit from this class if you implement a custom layer.
@ -32,3 +38,8 @@ class MambaBase(ABC):
@abstractmethod
def mamba_type(self) -> str:
pass
@abstractmethod
def get_attn_backend(self) -> type["AttentionBackend"]:
"""Get the attention backend class for this Mamba layer."""
pass

View File

@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple, Optional
from typing import TYPE_CHECKING, NamedTuple, Optional
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
from torch import nn
@ -404,6 +407,11 @@ class MambaMixer(MambaBase, CustomOp):
def mamba_type(self) -> str:
return "mamba1"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba1_attn import (
Mamba1AttentionBackend)
return Mamba1AttentionBackend
def _time_proj_bias(self) -> Optional[torch.Tensor]:
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
return self.dt_proj.bias.float()

View File

@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
from torch import nn
@ -758,6 +761,11 @@ class MambaMixer2(MambaBase, CustomOp):
def mamba_type(self) -> str:
return "mamba2"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba2_attn import (
Mamba2AttentionBackend)
return Mamba2AttentionBackend
def mamba_mixer2(
hidden_states: torch.Tensor,

View File

@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
@ -232,6 +235,11 @@ class ShortConv(MambaBase, CustomOp):
def mamba_type(self) -> str:
return "short_conv"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)
return ShortConvAttentionBackend
def short_conv(
hidden_states: torch.Tensor,

View File

@ -5,7 +5,7 @@ from collections.abc import Mapping, Set
from dataclasses import dataclass
from enum import IntEnum
from itertools import groupby
from typing import Callable, Optional, TypeVar, Union
from typing import Callable, Optional, TypeVar, Union, cast
import torch
import torch.nn as nn
@ -435,9 +435,31 @@ class EmbeddingPoolerHead(PoolerHead):
def __init__(self) -> None:
super().__init__(activation=PoolerNormalize())
# Load ST projector if available
from vllm.config import get_current_vllm_config
from vllm.model_executor.models.adapters import _load_st_projector
vllm_config = get_current_vllm_config()
self.projector = _load_st_projector(
vllm_config.model_config) if vllm_config else None
def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
pooling_metadata: PoolingMetadata):
# Apply ST projector
if self.projector is not None:
projector = cast(nn.Module, self.projector)
def _proj(x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
y = projector(x.to(torch.float32))
return y.to(orig_dtype)
if isinstance(pooled_data, torch.Tensor):
pooled_data = _proj(pooled_data)
else:
pooled_data = [_proj(t) for t in pooled_data]
pooling_params = get_pooling_params(pooling_metadata)
if isinstance(pooled_data, list):

View File

@ -7,15 +7,21 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.models.config import VerifyAndUpdateConfig
from vllm.transformers_utils.config import (get_hf_file_bytes,
get_hf_file_to_dict)
from .interfaces_base import VllmModelForPooling, is_pooling_model
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
_T = TypeVar("_T", bound=type[nn.Module])
logger = init_logger(__name__)
_GENERATE_SUFFIXES = [
"ForCausalLM",
"ForConditionalGeneration",
@ -24,6 +30,96 @@ _GENERATE_SUFFIXES = [
]
def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
"""Load Sentence-Transformers Dense projection layers."""
try:
modules = get_hf_file_to_dict("modules.json", model_config.model,
model_config.revision)
if not modules:
return None
if isinstance(modules, dict):
modules = modules.get("modules", [])
dense_modules = [
m for m in modules
if m.get("type") == "sentence_transformers.models.Dense"
]
if not dense_modules:
return None
module = dense_modules[0]
folder = module.get("path", "")
config_path = f"{folder}/config.json" if folder else "config.json"
layer_config = get_hf_file_to_dict(config_path, model_config.model,
model_config.revision)
if not layer_config:
return None
linear = nn.Linear(layer_config.get("in_features", 768),
layer_config.get("out_features", 768),
bias=layer_config.get("bias", True),
dtype=torch.float32)
if _load_dense_weights(linear, folder, model_config):
layers = [linear]
if act_name := layer_config.get("activation_function"):
layers.append(get_act_fn(act_name))
return nn.Sequential(*layers).to(dtype=torch.float32)
except Exception:
logger.exception("ST projector loading failed")
return None
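The modules.json filtering above can be sketched against a hypothetical in-memory document instead of a Hub download (the layout below is an assumption, modeled on typical Sentence-Transformers repos):

```python
import json

# Hypothetical modules.json content for a Sentence-Transformers model.
modules = json.loads("""
[
  {"idx": 0, "name": "0", "path": "",
   "type": "sentence_transformers.models.Transformer"},
  {"idx": 1, "name": "1", "path": "1_Pooling",
   "type": "sentence_transformers.models.Pooling"},
  {"idx": 2, "name": "2", "path": "2_Dense",
   "type": "sentence_transformers.models.Dense"}
]
""")

# Keep only the Dense projection modules, as _load_st_projector does.
dense_modules = [
    m for m in modules
    if m.get("type") == "sentence_transformers.models.Dense"
]
folder = dense_modules[0].get("path", "")
config_path = f"{folder}/config.json" if folder else "config.json"
print(config_path)
```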
def _load_dense_weights(linear: nn.Linear, folder: str,
model_config: "ModelConfig") -> bool:
"""Load weights using vLLM's weight_loader pattern."""
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader)
for filename in ["model.safetensors", "pytorch_model.bin"]:
file_path = f"{folder}/{filename}" if folder else filename
try:
file_bytes = get_hf_file_bytes(file_path, model_config.model,
model_config.revision)
if not file_bytes:
continue
if filename.endswith(".safetensors"):
from safetensors.torch import load as load_safetensors
state_dict = load_safetensors(file_bytes)
else:
import io
state_dict = torch.load(io.BytesIO(file_bytes),
map_location="cpu",
weights_only=True)
for weight_key in ["weight", "linear.weight", "dense.weight"]:
if weight_key in state_dict:
weight_loader = getattr(linear.weight, "weight_loader",
default_weight_loader)
weight_loader(linear.weight,
state_dict[weight_key].to(torch.float32))
bias_key = weight_key.replace("weight", "bias")
if linear.bias is not None and bias_key in state_dict:
bias_loader = getattr(linear.bias, "weight_loader",
default_weight_loader)
bias_loader(linear.bias,
state_dict[bias_key].to(torch.float32))
return True
except Exception:
logger.exception("Failed to load %s", filename)
continue
return False
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
model_name = orig_model_name

View File

@ -3,7 +3,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union
from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
PromptIndexTargets, PromptInsertion,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.utils.tensor_schema import TensorSchema, TensorShape
class MBartDecoderWrapper(nn.Module):
@ -132,10 +133,16 @@ class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only):
return loaded_params
class DonutImagePixelInputs(TypedDict):
class DonutImagePixelInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- c: Number of channels (3)
- h: Height
- w: Width
"""
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channel, height, width)"""
data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")]
class DonutProcessingInfo(BaseProcessingInfo):
@ -275,27 +282,6 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
)
self.pad_token_id = config.pad_token_id
def _validate_pixel_values(
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
# size = self.processor_config["size"]
h, w = self.config.encoder.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
raise ValueError(
"The expected shape of pixel values per batch "
f"is {expected_dims}. You supplied {actual_dims}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(self, **kwargs: object):
pixel_values: Optional[Union[list[list[torch.Tensor]],
list[torch.Tensor],
@ -314,11 +300,14 @@ class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
"Both pixel values and image embeds are provided.")
if pixel_values is not None:
return DonutImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
)
h, w = self.config.encoder.image_size
return DonutImagePixelInputs(type="pixel_values",
data=flatten_bn(pixel_values,
concat=True),
resolve_bindings={
"h": h,
"w": w,
})
if image_embeds is not None:
raise NotImplementedError

View File

@ -22,11 +22,12 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BoundPromptUpdate,
BaseProcessingInfo,
MultiModalPromptUpdates,
MultiModalPromptUpdatesApplyResult,
PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch,
PromptUpdate, PromptUpdateDetails,
find_mm_placeholders,
PromptReplacement, PromptUpdate,
PromptUpdateDetails,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder
@ -337,14 +338,10 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
def _apply_token_matches(
self,
prompt: list[int],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[int]:
token_ids = super()._apply_token_matches(
prompt,
mm_matches,
mm_item_counts,
)
mm_prompt_updates: MultiModalPromptUpdates,
) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
token_ids, res = super()._apply_token_matches(prompt,
mm_prompt_updates)
# "\n\n\n" and "\n\n\n\n" are single tokens
# Since our replacement can insert "\n\n" next to "\n"
@ -373,13 +370,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
[newline_4],
)
return token_ids
return token_ids, res
def _find_mm_placeholders(
self,
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
new_token_ids: list[int],
mm_item_counts: Mapping[str, int],
mm_prompt_updates: MultiModalPromptUpdates,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
tokenizer = self.info.get_tokenizer()
@ -404,8 +400,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
repl_token_ids.extend(repl_toks)
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
mm_item_counts)
repls = super()._find_mm_placeholders(repl_token_ids,
mm_prompt_updates)
return {
modality: [

View File

@ -29,11 +29,12 @@ from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
MultiModalDataParser)
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BoundPromptUpdate,
BaseProcessingInfo,
MultiModalPromptUpdates,
MultiModalPromptUpdatesApplyResult,
PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch,
PromptUpdate, PromptUpdateDetails,
find_mm_placeholders,
PromptReplacement, PromptUpdate,
PromptUpdateDetails,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder
@ -254,14 +255,10 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
def _apply_token_matches(
self,
prompt: list[int],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[int]:
token_ids = super()._apply_token_matches(
prompt,
mm_matches,
mm_item_counts,
)
mm_prompt_updates: MultiModalPromptUpdates,
) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]:
token_ids, res = super()._apply_token_matches(prompt,
mm_prompt_updates)
# "\n\n\n" and "\n\n\n\n" are single tokens
# Since our replacement can insert "\n\n" next to "\n"
@ -290,13 +287,12 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
[newline_4],
)
return token_ids
return token_ids, res
def _find_mm_placeholders(
self,
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
new_token_ids: list[int],
mm_item_counts: Mapping[str, int],
mm_prompt_updates: MultiModalPromptUpdates,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
tokenizer = self.info.get_tokenizer()
@ -321,8 +317,8 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
repl_token_ids.extend(repl_toks)
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
mm_item_counts)
repls = super()._find_mm_placeholders(repl_token_ids,
mm_prompt_updates)
return {
modality: [

View File

@ -828,26 +828,19 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
target=[image_token_id] * num_image_tokens,
replacement=get_replacement_mantis,
)
])
], mm_item_counts)
prompt_ids, prompt, _ = self._apply_prompt_updates(
result["prompt_token_ids"],
mantis_mm_repls,
mm_item_counts,
)
unbound_orig_repls = self._get_prompt_updates(
orig_repls = self._get_mm_prompt_updates(
mm_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
orig_repls = self._bind_and_group_updates(unbound_orig_repls)
mm_placeholders = self._find_mm_placeholders(
orig_repls,
prompt_ids,
mm_item_counts,
)
mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls)
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
mm_placeholder_ranges = {

View File

@ -75,7 +75,7 @@ class LlavaOnevisionImagePixelInputs(TensorSchema):
pixel_values: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "np", 3, "h", "w"),
TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}),
]
image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]

View File

@ -4,7 +4,10 @@
import copy
import math
from collections.abc import Iterable
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import regex as re
import torch
@ -339,6 +342,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def mamba_type(self) -> str:
return "linear_attention"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.linear_attn import (
LinearAttentionBackend)
return LinearAttentionBackend
def get_state_dtype(self) -> tuple[torch.dtype]:
return MambaStateDtypeCalculator.linear_attention_state_dtype(
self.model_config.dtype,

View File

@ -38,7 +38,8 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BoundPromptUpdate,
BaseProcessingInfo,
MultiModalPromptUpdates,
PlaceholderFeaturesInfo,
PromptReplacement, PromptUpdate)
# yapf: enable
@ -431,24 +432,21 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
return [_IMAGE_TOKEN_ID] * num_image_tokens
num_images = mm_items.get_count("image", strict=False)
return [
PromptReplacement(
modality="image",
target=image_token,
target=image_tokens.__getitem__,
replacement=get_replacement_phi3v,
) for image_token in image_tokens[:num_images]
)
]
def _apply_prompt_updates(
self,
token_ids: list[int],
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
mm_item_counts: Mapping[str, int],
mm_prompt_updates: MultiModalPromptUpdates,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
# align to hf behavior when there are images
if len(mm_item_counts):
if len(mm_prompt_updates):
tokenizer = self.info.get_tokenizer()
# to decode token_ids to the original text, we need to
# 1. remove the first bos token
@ -484,7 +482,6 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
token_ids, text, placeholders = super()._apply_prompt_updates(
token_ids=token_ids,
mm_prompt_updates=mm_prompt_updates,
mm_item_counts=mm_item_counts,
)
# Keep the behavior in line with HF processor

View File

@ -1032,8 +1032,8 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer()
image_token_id = tokenizer.vocab[tokenizer.image_token]
audio_token_id = tokenizer.vocab[tokenizer.audio_token]
image_token_id: int = tokenizer.vocab[tokenizer.image_token]
audio_token_id: int = tokenizer.vocab[tokenizer.audio_token]
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
audio_processor = self.info.get_feature_extractor(
@ -1053,9 +1053,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
processor=hf_processor,
)
image_tokens = [image_token_id] * num_image_tokens
return image_tokens
return [image_token_id] * num_image_tokens
def get_audio_replacement_phi4mm(item_idx: int):
audios = mm_items.get_items("audio", AudioProcessorItems)
@ -1066,9 +1064,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
audio_embed_size = self.info._compute_audio_embed_size(
audio_frames)
audio_tokens = [audio_token_id] * audio_embed_size
return audio_tokens
return [audio_token_id] * audio_embed_size
return [
PromptReplacement(

View File

@ -824,9 +824,7 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
processor=hf_processor,
)
image_tokens = [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens
return image_tokens
return [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens
def get_audio_replacement_phi4mm(item_idx: int):
audios = mm_items.get_items("audio", AudioProcessorItems)
@ -837,28 +835,20 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
audio_embed_size = self.info._compute_audio_embed_size(
audio_frames)
audio_tokens = [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size
return [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size
return audio_tokens
num_images = mm_items.get_count("image", strict=False)
num_audios = mm_items.get_count("audio", strict=False)
image_repl = [
return [
PromptReplacement(
modality="image",
target=image_token,
target=image_tokens.__getitem__,
replacement=get_image_replacement_phi4mm,
) for image_token in image_tokens[:num_images]
]
audio_repl = [
),
PromptReplacement(
modality="audio",
target=audio_token,
target=audio_tokens.__getitem__,
replacement=get_audio_replacement_phi4mm,
) for audio_token in audio_tokens[:num_audios]
),
]
return image_repl + audio_repl
@MULTIMODAL_REGISTRY.register_processor(

View File

@ -309,9 +309,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
mm_prompt_updates,
prompt_ids,
mm_item_counts,
mm_prompt_updates,
)
self._validate_mm_placeholders(
mm_placeholders,
@ -328,7 +327,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
) = self._apply_prompt_updates(
prompt_ids,
mm_prompt_updates,
mm_item_counts,
)
self._validate_mm_placeholders(
mm_placeholders,

View File

@ -135,7 +135,7 @@ class Qwen2_5_VLVideoPixelInputs(TypedDict):
second_per_grid_ts: torch.Tensor
"""
The video time interval (in seconds) for each grid along the temporal
The video time interval (in seconds) for each grid along the temporal
dimension in the 3D position IDs. Returned when `videos` is not `None`.
"""
@ -852,6 +852,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP,
SupportsQuant):
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
}
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={

View File

@ -45,6 +45,9 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
@ -146,11 +149,20 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate")
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=f"{prefix}.gate")
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid gate quantization.
# See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
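The pattern above — passing `None` as the quantization config for a module that the upstream quantizer leaves unquantized — can be sketched in isolation. The config classes here are stand-ins, not the actual vLLM types:

```python
# Minimal sketch: conditionally drop the quantization config for a module
# (here, an MoE router gate). GPTQ-style checkpoints leave the gate
# unquantized, so the layer should run in full precision.
class QuantizationConfig: ...
class GPTQConfig(QuantizationConfig): ...
class GPTQMarlinConfig(QuantizationConfig): ...
class AWQConfig(QuantizationConfig): ...  # stand-in for "any other" config

def maybe_ignore_quant_config(quant_config):
    # Returning None means "build this layer unquantized"; all other
    # configs pass through unchanged.
    if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
        return None
    return quant_config
```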
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
@ -682,4 +694,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
return self.model.get_expert_mapping()

View File

@ -8,7 +8,7 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union
from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
@ -35,6 +35,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
@ -48,27 +49,42 @@ IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
class SkyworkR1VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values_flat: torch.Tensor
class SkyworkR1VImagePixelInputs(TensorSchema):
"""
Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
Dimensions:
- bnp: Batch size * number of images * (1 + num_patches)
- c: Number of channels (3)
- h: Height
- w: Width
- bn: Batch size * number of images
"""
type: Literal["pixel_values"] = "pixel_values"
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
pixel_values_flat: Annotated[
torch.Tensor,
TensorShape("bnp", 3, "h", "w"),
]
num_patches: Annotated[
torch.Tensor,
TensorShape("bn"),
]
class SkyworkR1VImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
class SkyworkR1VImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- ni: Number of images
- ifs: Image feature size
- hs: Hidden size (must match the hidden size of language model
backbone)
"""
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("ni", "ifs", "hs"),
]
SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs,
@ -731,26 +747,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
vit_embeds = self.mlp1(vit_embeds)
return vit_embeds
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f" per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]:
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
@ -788,10 +784,12 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return SkyworkR1VImagePixelInputs(
type="pixel_values",
pixel_values_flat=self._validate_pixel_values(
pixel_values_flat),
pixel_values_flat=pixel_values_flat,
num_patches=image_num_patches,
)
resolve_bindings={
"h": self.config.vision_config.image_size,
"w": self.config.vision_config.image_size,
})
raise AssertionError("This line should be unreachable.")
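The symbolic-dimension validation that `TensorSchema`/`TensorShape` performs — shared symbols like `"h"` must be consistent, literal dims must match, and `resolve_bindings` pins a symbol to a known value up front — can be approximated in a few lines of plain Python. This is an illustrative re-implementation, not the `vllm.utils.tensor_schema` API:

```python
# Illustrative shape checker: string entries in `spec` are symbolic dims,
# int entries are exact sizes, and `bindings` pre-resolves symbols
# (like resolve_bindings={"h": image_size, "w": image_size} above).
def check_shape(shape, spec, bindings=None):
    bindings = dict(bindings or {})
    if len(shape) != len(spec):
        raise ValueError(f"rank mismatch: {shape} vs {spec}")
    for actual, expected in zip(shape, spec):
        if isinstance(expected, int):
            if actual != expected:
                raise ValueError(f"expected dim {expected}, got {actual}")
        else:
            # First use binds the symbol; later uses must agree.
            bound = bindings.setdefault(expected, actual)
            if bound != actual:
                raise ValueError(f"{expected}: {bound} != {actual}")
    return bindings
```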

View File

@ -3,7 +3,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
Union, cast)
import torch
@ -34,6 +34,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@ -43,14 +44,28 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
from .vision import VisionEncoderInfo, get_vision_encoder_info
class TarsierImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
class TarsierImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height
- w: Width
"""
type: Literal["pixel_values"] = "pixel_values"
pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
class TarsierImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
class TarsierImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- ifs: Image feature size
- hs: Hidden size (must match the hidden size of language model
backbone)
"""
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
TarsierImageInputs = Union[TarsierImagePixelInputs,
@ -432,18 +447,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w) # Assuming 3 channels
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[TarsierImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@ -459,8 +462,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
return TarsierImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
pixel_values=flatten_bn(pixel_values, concat=True),
)
if image_embeds is not None:

File diff suppressed because it is too large

View File

@ -927,3 +927,25 @@ def get_model_path(model: Union[str, Path], revision: Optional[str] = None):
from huggingface_hub import snapshot_download
return snapshot_download(repo_id=model, **common_kwargs)
def get_hf_file_bytes(file_name: str,
model: Union[str, Path],
revision: Optional[str] = 'main') -> Optional[bytes]:
"""Get file contents from HuggingFace repository as bytes."""
file_path = try_get_local_file(model=model,
file_name=file_name,
revision=revision)
if file_path is None:
hf_hub_file = hf_hub_download(model,
file_name,
revision=revision,
token=_get_hf_token())
file_path = Path(hf_hub_file)
if file_path is not None and file_path.is_file():
with open(file_path, 'rb') as file:
return file.read()
return None
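The local-first lookup pattern in `get_hf_file_bytes` — try a local file, fall back to a download, return `None` if the file is absent — can be sketched without the Hub dependency. `download` here is a hypothetical stand-in for `hf_hub_download`:

```python
from pathlib import Path
from typing import Callable, Optional

# Sketch of the "local first, then hub" fallback. `download` is a
# stand-in callable that returns a Path to the fetched file.
def read_file_bytes(file_name: str, local_dir: Path,
                    download: Optional[Callable[[str], Path]] = None
                    ) -> Optional[bytes]:
    file_path = local_dir / file_name
    if not file_path.is_file() and download is not None:
        file_path = download(file_name)
    if file_path is not None and file_path.is_file():
        return file_path.read_bytes()
    return None
```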

View File

@ -1,11 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from collections import defaultdict
"""Attention layer with FlexAttention."""
from dataclasses import dataclass
from typing import Optional
from typing import TYPE_CHECKING, Optional, Union
import torch
import torch._dynamo.decorators
import torch.nn.functional as F
from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
_score_mod_signature,
create_block_mask,
@ -16,13 +18,17 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
is_quantized_kv_cache)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv, is_torch_equal_or_newer
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
create_block_mask_compiled = torch.compile(create_block_mask,
fullgraph=True,
mode="reduce-overhead")
@ -36,6 +42,23 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
torch.arange(len(counts), device=device, dtype=torch.int32), counts)
def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int):
difference = (multiple - (x.shape[dim] % multiple)) % multiple
if difference == 0:
return x
dim = dim if dim >= 0 else x.ndim + dim
pad_list = []
for i in range(x.ndim - 1, dim - 1, -1):
if i == dim:
pad_list.extend([0, difference])
else:
pad_list.extend([0, 0])
return F.pad(x, pad_list, mode="constant", value=0)
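The amount of padding computed by `pad_to_multiple` reduces to a one-line formula, shown here in isolation:

```python
# How many extra elements are needed to round size n up to a multiple
# of m. The outer "% m" makes the result 0 when n is already aligned.
def pad_amount(n: int, m: int) -> int:
    return (m - (n % m)) % m
```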
class FlexAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@ -77,10 +100,10 @@ class FlexAttentionBackend(AttentionBackend):
return False
# @torch.compile(fullgraph=True, mode="reduce-overhead")
def physical_to_logical_mapping(
block_table: torch.Tensor,
total_blocks: Optional[int] = None) -> torch.Tensor:
#@torch.compile(fullgraph=True, mode="reduce-overhead")
def physical_to_logical_mapping(block_table: torch.Tensor,
seq_lens: torch.Tensor, block_size: int,
total_blocks: int) -> torch.Tensor:
"""
Creates an inverse mapping from physical block locations to logical indices.
@ -114,13 +137,38 @@ def physical_to_logical_mapping(
If a physical block is not mapped to by any logical block,
its value in the result will be -1.
IMPORTANT: Garbage Value Protection
The block_table tensor may contain garbage values in unused positions
(beyond the actual sequence length). For example, if a sequence only
needs 3 blocks but the table has space for 8:
block_table[0] = [10, 25, 7, 999, 1234, 888, ...]
^^^^^^^^^^^^^^^^^^^^
garbage values
These garbage values can cause issues because:
1. They may map to valid physical blocks by coincidence
2. The scatter_ operation will assign them logical indices
3. Later attention computations may incorrectly access these blocks
To prevent this, we use seq_lens and block_size to mask out unused
entries, ensuring only valid block references are processed.
Args:
block_table: Tensor of shape [max_reqs, max_num_blocks]
mapping logical blocks to physical locations
mapping logical blocks to physical locations. May contain
garbage values in unused positions.
seq_lens: Tensor of sequence lengths for each request. Used to
determine how many blocks are actually needed per sequence.
block_size: Size of each block in tokens. Used with seq_lens to
compute the number of valid blocks per sequence.
total_blocks: Total number of physical blocks available
Returns:
A tensor of shape [max_reqs, max_physical_block]
A tensor of shape [max_reqs, total_blocks] where each entry
physical_to_logical[req_id, physical_block] contains the logical
block index for that physical block, or -1 if unused.
"""
max_reqs, max_num_blocks = block_table.shape
device = block_table.device
@ -130,17 +178,76 @@ def physical_to_logical_mapping(
dtype=torch.long,
device=device)
logical_indices = (torch.arange(max_num_blocks,
device=device).unsqueeze(0).expand(
max_reqs, -1))
# Only process valid blocks to avoid garbage values
num_blocks_per_seq = cdiv(seq_lens, block_size)
mask = torch.arange(max_num_blocks,
device=device)[None, :] < num_blocks_per_seq[:, None]
physical_to_logical.scatter_(-1, block_table.to(torch.int64),
logical_indices)
# TODO Confirm - Seems like block 0 is always empty so we reset it manually
valid_block_table = torch.where(mask, block_table, 0)
valid_logical_indices = torch.where(
mask,
torch.arange(max_num_blocks, device=device)[None, :], 0)
physical_to_logical.scatter_(-1, valid_block_table.to(torch.int64),
valid_logical_indices)
# NB - Seems like block 0 is always empty so we reset it manually
physical_to_logical[:, 0] = -1
return physical_to_logical
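The garbage-value protection described in the docstring can be illustrated with a pure-Python sketch (not the vectorized torch implementation above): only the first `ceil(seq_len / block_size)` entries of each block-table row are trusted, and everything beyond that is ignored.

```python
# Pure-Python sketch of the inverse mapping with garbage-value
# protection. Unmapped physical blocks stay at -1.
def physical_to_logical(block_table, seq_lens, block_size, total_blocks):
    result = []
    for row, seq_len in zip(block_table, seq_lens):
        num_valid = -(-seq_len // block_size)  # ceil division
        inverse = [-1] * total_blocks
        # Entries past num_valid may be stale garbage; skip them.
        for logical_idx, physical in enumerate(row[:num_valid]):
            inverse[physical] = logical_idx
        inverse[0] = -1  # block 0 is treated as always empty, as above
        result.append(inverse)
    return result
```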
def unique_static_unsorted(
x: torch.Tensor,
*,
M: int, # maximum positive value (0 is “skip me”)
dim: int = -1, # axis along which to deduplicate
ignored_val: int = 0, # value to ignore
pad_val: int = -1, # sentinel for unused slots
) -> torch.Tensor:
"""
- Keeps the first occurrence of each non-zero value while preserving order,
then left-packs those uniques and fills the rest with `pad_val`.
- Returns (packed, keep_mask) with the *same shape* as `x`.
- Requires that all values be in the range [0, M]
- Skips ignored_val
Works on CPU or GPU, no Python loops, O(B·N) time / O(B·M) memory.
Example:
x = [3, 1, 0, 1, 2], M=3, ignored_val=0 => [3, 1, 2, -1, -1]
"""
if not (-1 <= pad_val <= M):
raise ValueError("`pad_val` must lie in [-1, M]")
# ── move `dim` to the end so we can treat tensor as [B, N] ──────────
dim = dim % x.ndim
x_perm = x.movedim(dim, -1) # shape [..., N]
B, N = x_perm.numel() // x_perm.shape[-1], x_perm.shape[-1]
x_flat = x_perm.reshape(B, N) # [B, N]
device = x.device
idx = torch.arange(N, device=device).expand(B, N) # per-row indices
# ── build first-occurrence table for every v ∈ [0, M] ───────────────
first_idx = torch.full((B, M + 1), N, device=device) # “∞”
# scatter_reduce_: first_idx[b, v] = min(first_idx[b, v], i) for each i
first_idx.scatter_reduce_(1, x_flat, idx, reduce="amin")
# ── keep mask: first occurrence *and* value ≠ 0 ─────────────────────
keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat)
) # [B, N]
# ── left-pack uniques into a fresh tensor ───────────────────────────
dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1 # where to go
packed_flat = torch.full_like(x_flat, pad_val)
rows, src_cols = torch.nonzero(keep, as_tuple=True)
packed_flat[rows, dest_pos[rows, src_cols]] = x_flat[rows, src_cols]
# ── restore original layout ─────────────────────────────────────────
packed = packed_flat.reshape(x_perm.shape).movedim(-1, dim)
return packed
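The docstring example above can be checked against a simple per-row reference implementation (pure Python, no tensors), which makes the intended semantics explicit:

```python
# Per-row reference for unique_static_unsorted: keep the first
# occurrence of each non-ignored value in order, left-pack, and pad
# with pad_val to preserve the row length.
def unique_unsorted_row(row, ignored_val=0, pad_val=-1):
    seen = set()
    packed = []
    for v in row:
        if v != ignored_val and v not in seen:
            seen.add(v)
            packed.append(v)
    return packed + [pad_val] * (len(row) - len(packed))
```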
def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
kv_idx: torch.Tensor):
return q_idx >= kv_idx
@ -170,6 +277,7 @@ class FlexAttentionMetadata:
num_reqs: int
physical_to_logical: torch.Tensor
decode_offset: torch.Tensor
num_blocks_per_seq: torch.Tensor
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
@ -179,6 +287,46 @@ class FlexAttentionMetadata:
block_mask: Optional[BlockMask] = None
score_mod: Optional[_score_mod_signature] = None
logical_mask_mod: _mask_mod_signature = causal_mask_mod
doc_ids: Optional[torch.Tensor] = None
direct_build: bool = True
q_block_size: int = 16
kv_block_size: int = 16
transformed_score_mod: Optional[_score_mod_signature] = None
def _convert_physical_to_logical(
self,
request_lookup: torch.Tensor,
q_idx: torch.Tensor,
physical_kv_idx: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Convert physical indices to logical indices for both query and kv.
NB is_within_lower_bound: do sequences start on block_boundaries?
Returns:
tuple of (is_valid, logical_q_idx, logical_kv_idx)
"""
# Map query indices to corresponding request indices
q_req = request_lookup[q_idx]
# Convert physical KV indices to logical indices
physical_kv_block = physical_kv_idx // self.block_size
physical_kv_offset = physical_kv_idx % self.block_size
logical_block_idx = self.physical_to_logical[q_req, physical_kv_block]
logical_kv_idx = (logical_block_idx * self.block_size +
physical_kv_offset)
# Determine valid kv indices
live_block = logical_block_idx >= 0
within_upper_bound = logical_kv_idx < self.seq_lens[q_req]
within_lower_bound = logical_kv_idx >= 0
is_valid = live_block & within_upper_bound & within_lower_bound
# Convert physical query indices to logical indices
local_q_idx = q_idx - self.query_start_loc[q_req]
logical_q_idx = local_q_idx + self.decode_offset[q_req]
return is_valid, logical_q_idx, logical_kv_idx
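The KV-side arithmetic of `_convert_physical_to_logical` can be shown as a scalar sketch for a single physical KV index (the real method operates on whole index tensors and also remaps the query index):

```python
# Scalar sketch of the physical -> logical KV conversion. `p2l` is one
# row of the inverse block table; -1 marks an unmapped physical block.
def convert_kv_index(physical_kv_idx, p2l, block_size, seq_len):
    physical_block = physical_kv_idx // block_size
    offset = physical_kv_idx % block_size
    logical_block = p2l[physical_block]
    logical_kv_idx = logical_block * block_size + offset
    # Valid only if the block is live and the index is inside the
    # sequence (lower bound guards against the -1 sentinel).
    is_valid = logical_block >= 0 and 0 <= logical_kv_idx < seq_len
    return is_valid, logical_kv_idx
```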
def get_causal_mask_mod(self) -> _mask_mod_signature:
"""Creates the mask_mod function for FlexAttention.
@ -191,11 +339,8 @@ class FlexAttentionMetadata:
With this info we create the "logical" indices that are passed to
mask_mod functions. This allows mask mod functions to be agnostic to
layout of the query and key/value tensors.
TODO is_within_lower_bound: do sequences start on block_boundaries?
"""
# Create a lookup mapping from query indices -> request number
request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
assert self.doc_ids is not None
def final_mask_mod(
b: torch.Tensor,
@ -203,27 +348,9 @@ class FlexAttentionMetadata:
q_idx: torch.Tensor,
physical_kv_idx: torch.Tensor,
) -> torch.Tensor:
# Map query indices to corresponding request indices
q_req = request_lookup[q_idx]
# Convert physical KV indices to logical indices
physical_kv_block = physical_kv_idx // self.block_size
physical_kv_offset = physical_kv_idx % self.block_size
logical_block_idx = self.physical_to_logical[q_req,
physical_kv_block]
logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset # noqa: E501
# Determine valid kv indices
live_block = logical_block_idx >= 0
within_upper_bound = logical_kv_idx < self.seq_lens[q_req]
within_lower_bound = logical_kv_idx >= 0
is_valid = live_block & within_upper_bound & within_lower_bound
# Convert physical query indices to logical indices
local_q_idx = q_idx - self.query_start_loc[q_req]
logical_q_idx = local_q_idx + self.decode_offset[q_req]
(is_valid, logical_q_idx,
logical_kv_idx) = self._convert_physical_to_logical(
self.doc_ids, q_idx, physical_kv_idx)
# Apply mask modification only for valid indices
return torch.where(
is_valid,
@ -236,7 +363,7 @@ class FlexAttentionMetadata:
def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
"""Creates the encoder mask_mod function for FlexAttention.
Since the encoder bidirectional attention doesn't run with
Since the encoder bidirectional attention doesn't run with
KV cache, this function creates a mask based on the
packed query sequences.
"""
@ -253,6 +380,97 @@ class FlexAttentionMetadata:
return final_mask_mod
def get_transformed_score_mod(self) -> Optional[_score_mod_signature]:
"""Creates the transformed score_mod function for FlexAttention.
This function wraps the user's score_mod to handle physical-to-logical
index conversion, similar to how get_mask_mod works for mask functions.
"""
if self.score_mod is None:
return None
# Create a lookup mapping from query indices -> request number
request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
user_score_mod = self.score_mod
def transformed_score_mod(
score: torch.Tensor,
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
physical_kv_idx: torch.Tensor,
) -> torch.Tensor:
(is_valid, logical_q_idx,
logical_kv_idx) = self._convert_physical_to_logical(
request_lookup, q_idx, physical_kv_idx)
return torch.where(
is_valid,
user_score_mod(score,
b,
h,
logical_q_idx,
logical_kv_idx,
physical_q=q_idx), -float('inf'))
return transformed_score_mod
def _build_block_mask_direct(self) -> BlockMask:
"""Direct block mask construction for standard causal attention.
This method constructs the block mask directly using
BlockMask.from_kv_blocks which is much more efficient than the
generic create_block_mask approach.
The direct path works as follows:
1. For each query token, fetch blocks from block_table using max_seq_len
(this fetches more blocks than needed for shorter sequences)
2. Group query tokens into chunks of q_block_size
3. For each group, deduplicate the blocks using unique_static_unsorted
4. Create BlockMask using the deduplicated block indices
Over-estimation occurs when a group of q_block_size tokens contains
multiple sequence IDs (doc_ids). In this case, we fetch ALL blocks for
each sequence represented in the group, even though individual query
tokens may only need a subset of those blocks based on causal masking
and their position.
"""
page_to_block_ratio = self.kv_block_size // self.block_size
if page_to_block_ratio != 1:
raise ValueError(
f"FlexAttention currently requires the cache block size "
f"({self.block_size}) to be equal to the kv_block_size "
f"({self.kv_block_size}). Please check your model's "
f"configuration.")
used_pages = self.block_table[
self.doc_ids, :cdiv(self.max_seq_len, self.block_size)]
used_pages_padded = pad_to_multiple(used_pages,
multiple=self.q_block_size,
dim=0)
used_pages_padded = used_pages_padded.reshape(
used_pages_padded.shape[0] // self.q_block_size, -1)
used_pages_padded = used_pages_padded // page_to_block_ratio
kv_indices = unique_static_unsorted((used_pages_padded.long()),
M=self.num_blocks).to(torch.int32)
kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32)
block_mask_kwargs = {
"seq_lengths": (self.num_actual_tokens, self.total_cache_tokens),
"kv_num_blocks": kv_num_blocks[None, None],
"kv_indices": kv_indices[None, None],
"full_kv_num_blocks": None,
"full_kv_indices": None,
"BLOCK_SIZE": (self.q_block_size, self.kv_block_size),
"mask_mod": self.mask_mod,
}
# compute_q_blocks parameter is available in PyTorch 2.9+
if is_torch_equal_or_newer("2.9.0.dev0"):
block_mask_kwargs["compute_q_blocks"] = False
return BlockMask.from_kv_blocks(**block_mask_kwargs)
def build_block_mask(self) -> BlockMask:
if self.causal:
mask_mod = self.get_causal_mask_mod()
@ -267,6 +485,7 @@ class FlexAttentionMetadata:
self.num_actual_tokens,
kv_len,
device=self.block_table.device,
BLOCK_SIZE=(self.q_block_size, self.kv_block_size),
)
def __post_init__(self):
@ -275,8 +494,21 @@ class FlexAttentionMetadata:
assert self.cu_prefix_query_lens is None, "Not implemented yet."
assert self.prefix_kv_lens is None, "Not implemented yet."
assert self.suffix_kv_lens is None, "Not implemented yet."
# Create a lookup mapping from query indices -> request number
self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)
self.num_blocks = self.total_cache_tokens // self.block_size
self.block_mask = self.build_block_mask()
if self.causal:
self.mask_mod = self.get_causal_mask_mod()
else:
self.mask_mod = self.get_bidirectional_mask_mod()
self.transformed_score_mod = self.get_transformed_score_mod()
if self.direct_build and self.causal:
self.block_mask = self._build_block_mask_direct()
else:
self.block_mask = self.build_block_mask()
class FlexAttentionMetadataBuilder(
@ -287,15 +519,24 @@ class FlexAttentionMetadataBuilder(
self.model_config = vllm_config.model_config
self.parallel_config = vllm_config.parallel_config
self.cache_config = vllm_config.cache_config
self.device = device
self.num_heads_q = self.model_config.get_num_attention_heads(
vllm_config.parallel_config)
self.parallel_config)
self.num_heads_kv = self.model_config.get_num_kv_heads(
vllm_config.parallel_config)
self.parallel_config)
self.headdim = self.model_config.get_head_size()
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.device = device
self.direct_build: bool = is_torch_equal_or_newer("2.9.0.dev0")
self.q_block_size: int = 16 if is_torch_equal_or_newer(
"2.9.0.dev0") else 128
self.kv_block_size: int = 16 if is_torch_equal_or_newer(
"2.9.0.dev0") else 128
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
return False
def build(self,
common_prefix_len: int,
@ -310,6 +551,7 @@ class FlexAttentionMetadataBuilder(
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
num_blocks_per_seq = cdiv(seq_lens, self.block_size)
use_cascade = common_prefix_len > 0
cu_prefix_query_lens = None
@ -320,12 +562,15 @@ class FlexAttentionMetadataBuilder(
block_size = self.kv_cache_spec.block_size
max_possible_seq_len = self.model_config.max_model_len
total_cache_tokens = self.cache_config.num_gpu_blocks * block_size
num_gpu_blocks = self.cache_config.num_gpu_blocks
assert num_gpu_blocks is not None, \
"FlexAttention requires num_gpu_blocks to be set"
total_cache_tokens = (num_gpu_blocks * block_size)
inverse_block_table = physical_to_logical_mapping(
block_table_tensor, self.cache_config.num_gpu_blocks)
block_table_tensor, seq_lens, block_size, num_gpu_blocks)
# Get the original offset tensor
offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
self.device, non_blocking=True)
@ -349,9 +594,16 @@ class FlexAttentionMetadataBuilder(
physical_to_logical=inverse_block_table,
total_cache_tokens=total_cache_tokens,
decode_offset=offset_tensor,
num_blocks_per_seq=num_blocks_per_seq,
direct_build=self.direct_build,
q_block_size=self.q_block_size,
kv_block_size=self.kv_block_size,
)
return out
def use_cascade_attention(self, *args, **kwargs) -> bool:
return False
class FlexAttentionImpl(AttentionImpl):
sliding_window: Optional[tuple[int, int]]
@ -370,6 +622,7 @@ class FlexAttentionImpl(AttentionImpl):
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
**kwargs,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@ -398,6 +651,7 @@ class FlexAttentionImpl(AttentionImpl):
raise NotImplementedError(
"FlexAttention does not support logits soft cap yet.")
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if kv_sharing_target_layer_name is not None:
@ -405,7 +659,6 @@ class FlexAttentionImpl(AttentionImpl):
"FlexAttention does not support kv sharing yet.")
FlexAttentionBackend.validate_head_size(head_size)
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlexAttention does not support quantized kv-cache. Yet")
@ -493,35 +746,48 @@ class FlexAttentionImpl(AttentionImpl):
# Doesn't work for now -> constraint violation
# torch._dynamo.try_mark_dynamic(query, 2)
# default M=64, N=64 may run out of shared memory on some GPUs
# TODO: Explicit configs for each GPU?
# Not sure how to calculate the shared memory requirement
extra_kernel_options = defaultdict[str, int](lambda: 64)
if query.dtype == torch.float32:
extra_kernel_options["BLOCK_M"] //= 2
extra_kernel_options["BLOCK_N"] //= 2
if current_platform.is_cuda():
device_props = torch.cuda.get_device_properties()
max_shared_memory = device_props.shared_memory_per_block_optin
if max_shared_memory < 144 * 1024:
extra_kernel_options["BLOCK_M"] //= 2
extra_kernel_options["BLOCK_N"] //= 2
assert attn_metadata.block_mask is not None
block_m, block_n = attn_metadata.block_mask.BLOCK_SIZE
kernel_options = get_kernel_options(query, block_m, block_n,
attn_metadata.direct_build)
out = flex_attention_compiled(
query,
key_tensor,
value_tensor,
attn_metadata.score_mod,
attn_metadata.transformed_score_mod,
attn_metadata.block_mask,
self.scale,
enable_gqa=enable_gqa,
kernel_options={
"FORCE_USE_FLEX_ATTENTION": True,
**extra_kernel_options
},
kernel_options=kernel_options,
)
# Flex doesn't have an out variant today, rely on epilogue fusion
out = out.permute(0, 2, 1, 3).squeeze(0)
output[:num_actual_tokens, :, :].copy_(out)
return output
def get_kernel_options(query, block_m, block_n,
use_direct_build: bool) -> dict[str, Union[int, bool]]:
kernel_options: dict[str, Union[int, bool]] = {
"FORCE_USE_FLEX_ATTENTION": True,
}
if use_direct_build:
kernel_options["BLOCK_M"] = block_m
kernel_options["BLOCK_N"] = block_n
return kernel_options
else:
kernel_options["BLOCK_M"] = 64
kernel_options["BLOCK_N"] = 64
if query.dtype == torch.float32:
kernel_options["BLOCK_M"] = 32
kernel_options["BLOCK_N"] = 32
# if current_platform.is_cuda():
if torch.cuda.is_available():
device_props = torch.cuda.get_device_properties()
max_shared_memory = device_props.shared_memory_per_block_optin
if max_shared_memory < 144 * 1024:
kernel_options["BLOCK_M"] = 32
kernel_options["BLOCK_N"] = 32
return kernel_options

View File

@ -1,22 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)
def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
if mamba_type == "mamba1":
return Mamba1AttentionBackend
if mamba_type == "mamba2":
return Mamba2AttentionBackend
if mamba_type == "linear_attention":
return LinearAttentionBackend
if mamba_type == "short_conv":
return ShortConvAttentionBackend
raise NotImplementedError(f"Mamba Attention type {mamba_type} is not "
"supported yet.")

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import OrderedDict
from collections.abc import Mapping
from typing import TYPE_CHECKING
@ -31,34 +33,52 @@ class EncoderCacheManager:
within requests, allowing for fine-grained memory management and enabling
chunked processing of multimodal inputs.
Note that no caching is shared between requests at this time. If the same
input is used across multiple requests, it will be reprocessed for each
request.
The cache shares embeddings of the same multimodal data item
(identified by its hash value) between different requests, and
eviction takes place at allocation time when there is no free
space for new embeddings.
The oldest cached embeddings not referenced by any request are evicted
first.
Args:
cache_size: Limit the size of the cache, measured by the number of
tokens from the input sequence.
Attributes:
cache_size: Total cache capacity in encoder tokens
num_free_slots: Current available cache capacity in encoder tokens
cached: Mapping from request_id to set of cached input_ids for that
request
freed: List of (request_id, input_id) pairs that were recently freed.
This is cleared after every call to get_freed_ids().
cache_size: Total cache capacity in encoder tokens.
num_free_slots: Current available cache capacity in encoder tokens.
num_freeable_slots: Capacity that can be immediately reclaimed by
evicting entries with zero references (in encoder tokens).
cached: Mapping from mm_hash to a set of request IDs that currently
reference the cached entry. If the set is empty, the entry exists
but is not referenced by any request and is eligible for
reclamation.
freeable: Ordered mapping from mm_hash to num_tokens for entries that
no currently running request references and that can be freed to
make space when needed.
freed: List of mm_hash strings that were actually evicted since the
last call to get_freed_mm_hashes(). This list is cleared on return.
"""
def __init__(self, cache_size: int):
self.cache_size = cache_size
self.num_free_slots = cache_size
# req_id -> cached input ids
self.cached: dict[str, set[int]] = {}
# list of [req_id, input_id]
self.freed: list[tuple[str, int]] = []
self.num_freeable_slots = cache_size
def has_cache(self, request: Request, input_id: int) -> bool:
# mm_hash of mm_data => ids of requests that reference the mm_data
self.cached: dict[str, set[str]] = {}
# mm_hash of mm_data => num_encoder_tokens of the mm_data
self.freeable: OrderedDict[str, int] = OrderedDict()
self.freed: list[str] = []
def check_and_update_cache(self, request: Request, input_id: int) -> bool:
"""Check if encoder output for a specific multimodal input is cached.
If the encoder output is cached, update `cached` to add the request id
to the set of request ids that reference the cached encoder output.
If the encoder output was previously not referenced by any request,
update `freeable` and `num_freeable_slots` accordingly.
Args:
request: The request containing the multimodal input
input_id: Index of the multimodal input within the request
@ -66,103 +86,151 @@ class EncoderCacheManager:
Returns:
True if the encoder output for this input is already cached
"""
req_id = request.request_id
return req_id in self.cached and input_id in self.cached[req_id]
mm_hash = request.mm_hashes[input_id]
# Not cached at all
if mm_hash not in self.cached:
return False
def can_allocate(self, request: Request, input_id: int) -> bool:
"""Check if there's sufficient cache space for a multimodal input.
# Cached but currently not referenced by any request
if not self.cached[mm_hash]:
num_tokens = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_tokens
self.cached[mm_hash].add(request.request_id)
return True
def try_allocate(self, request: Request, input_id: int,
encoder_budget: int) -> bool:
"""Check if there's sufficient cache space for a multimodal input.
If there is, return True and update EncoderCacheManager state.
If there is not enough free space in `num_free_slots` but there is
enough reclaimable space in `num_freeable_slots`, entries will be
evicted from `freeable` (their mm_hash appended to `freed`) until
enough space is available, and then this method returns True.
Older entries are evicted first.
Returns False if the input exceeds the remaining encoder compute
budget, or if the requested number of tokens exceeds the free and
reclaimable capacities combined.
Args:
request: The request containing the multimodal input
input_id: Index of the multimodal input within the request
request: The request containing the multimodal input.
input_id: Index of the multimodal input within the request.
encoder_budget: Remaining encoder compute budget for this scheduling
step, in tokens.
Returns:
True if there's enough free cache space to store the encoder output
for this multimodal input
True if there's enough capacity to hold the encoder output for this
input (possibly after reclaiming `freeable` entries); otherwise
False.
Note: This method does not allocate physical memory for the encoder
output; it only updates the state of the EncoderCacheManager.
"""
num_tokens = request.get_num_encoder_tokens(input_id)
return num_tokens <= self.num_free_slots
# Not enough compute budget
if num_tokens > encoder_budget:
return False
# Enough free slots
if num_tokens <= self.num_free_slots:
self.num_free_slots -= num_tokens
self.num_freeable_slots -= num_tokens
return True
# Not enough reclaimable slots
if num_tokens > self.num_freeable_slots:
return False
# Not enough free slots but enough reclaimable slots
# NOTE: Eviction takes place here, but physical memory is not freed
# until the model runner is notified by the scheduler output.
while num_tokens > self.num_free_slots:
mm_hash, num_free_token = self.freeable.popitem(last=False)
del self.cached[mm_hash]
self.freed.append(mm_hash)
self.num_free_slots += num_free_token
self.num_free_slots -= num_tokens
self.num_freeable_slots -= num_tokens
return True
def allocate(self, request: Request, input_id: int) -> None:
"""Allocate cache space for a multimodal input's encoder output.
This method reserves cache space for storing the encoder output of
the specified multimodal input. The actual encoder output storage
happens in the model runner, but this method ensures the cache
manager tracks the allocation.
Args:
request: The request containing the multimodal input
input_id: Index of the multimodal input within the request
This reserves cache space for storing the encoder output of the
specified multimodal input. The actual encoder output storage happens in
the model runner; this method updates the manager's bookkeeping.
Note:
This method assumes can_allocate() returned True for the same
request and input_id. It will reduce available cache space.
This method assumes try_allocate() returned True for the same input.
"""
req_id = request.request_id
if req_id not in self.cached:
self.cached[req_id] = set()
self.cached[req_id].add(input_id)
self.num_free_slots -= request.get_num_encoder_tokens(input_id)
# The encoder cache space budget should already have been updated for
# this multimodal input and be non-negative after try_allocate().
assert self.num_free_slots >= 0
assert self.num_freeable_slots >= 0
mm_hash = request.mm_hashes[input_id]
request_id = request.request_id
if mm_hash not in self.cached:
self.cached[mm_hash] = set()
self.cached[mm_hash].add(request_id)
def get_cached_input_ids(self, request: Request) -> set[int]:
"""Get all cached multimodal input IDs for a request.
Args:
request: The request to query
Returns:
Set of input_ids that have cached encoder outputs for this request.
Returns empty set if no inputs are cached for this request.
Returns the set of input IDs whose `mm_hash` exists in the cache map.
This includes entries that are currently unreferenced (and thus present
in `freeable`); for such entries, freeing for this request will be a
no-op.
"""
return self.cached.get(request.request_id, set())
return {
input_id
for input_id in range(len(request.mm_hashes))
if request.mm_hashes[input_id] in self.cached
}
def free_encoder_input(self, request: Request, input_id: int) -> None:
"""Free cache space for a single multimodal input's encoder output.
"""Free the request's reference to the encoder input (`mm_data`)
This method is called when:
- The encoder output has been fully consumed by the decoder and is
no longer needed (e.g., in vision-language models after image
tokens are processed)
- A request is being cancelled or aborted
When the reference set for the corresponding `mm_hash` becomes empty,
the entry is appended to `freeable` and `num_freeable_slots` is
increased by the number of encoder tokens for that input.
Args:
request: The request containing the multimodal input
input_id: Index of the multimodal input to free from cache
The entry is NOT physically freed until capacity is needed (e.g., by
`try_allocate`).
"""
req_id = request.request_id
if req_id not in self.cached:
mm_hash = request.mm_hashes[input_id]
# The mm_hash is not in the cache or its req_id set is empty
if not self.cached.get(mm_hash, None):
return
self.cached[req_id].discard(input_id)
if len(self.cached[req_id]) == 0:
del self.cached[req_id]
self.num_free_slots += request.get_num_encoder_tokens(input_id)
self.freed.append((req_id, input_id))
self.cached[mm_hash].discard(req_id)
if not self.cached[mm_hash]:
num_tokens = request.get_num_encoder_tokens(input_id)
self.freeable[mm_hash] = num_tokens
self.num_freeable_slots += num_tokens
def free(self, request: Request) -> None:
"""Free all cached encoder outputs for a request.
"""Free all encoder input cache reference held by *request*.
This method is typically called when a request is finished, cancelled,
or aborted, and all its encoder outputs should be freed from cache.
For each cached input ID, `free_encoder_input` is invoked. The data
stays in memory until eviction is triggered by a future allocation
attempt via `try_allocate`.
Args:
request: The request whose encoder outputs should be freed
Typically called when a request is finished, cancelled, or aborted.
"""
input_ids = self.get_cached_input_ids(request).copy()
for input_id in input_ids:
self.free_encoder_input(request, input_id)
def get_freed_ids(self) -> list[tuple[str, int]]:
def get_freed_mm_hashes(self) -> list[str]:
"""Get and clear the list of recently freed encoder cache entries.
This method returns all encoder cache entries that were freed since
the last call to this method. It's used by the scheduler to notify
workers about which encoder outputs can be removed from their caches.
Returns:
List of (request_id, input_id) tuples that were freed since the
last call. The internal freed list is cleared after this call.
List of mm_hash strings that were actually evicted since the last
call to be used by the scheduler to notify workers about which
encoder outputs can be removed from their caches. The internal
list is cleared after this call.
"""
freed = self.freed
self.freed = []
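The eviction-on-allocate scheme above can be sketched in isolation. The following is a minimal model of the bookkeeping only (a hypothetical `TinyEncoderCache` with simplified signatures, not the vLLM API): entries are keyed by content hash, reference-counted per request id, parked in an LRU `freeable` map once unreferenced, and evicted only when a later allocation needs the space.

```python
from collections import OrderedDict


class TinyEncoderCache:
    """Toy model of eviction-on-allocate with per-request refcounts."""

    def __init__(self, cache_size: int):
        self.num_free_slots = cache_size
        self.num_freeable_slots = cache_size
        # mm_hash -> ids of requests referencing the entry
        self.cached: dict[str, set[str]] = {}
        # mm_hash -> num_tokens, oldest-unreferenced first
        self.freeable: OrderedDict[str, int] = OrderedDict()
        # hashes actually evicted since the last drain
        self.freed: list[str] = []

    def try_allocate(self, mm_hash: str, req_id: str,
                     num_tokens: int) -> bool:
        if num_tokens > self.num_freeable_slots:
            return False  # even evicting everything unreferenced won't fit
        while num_tokens > self.num_free_slots:
            # Evict the oldest unreferenced entry to make room.
            evicted_hash, evicted_tokens = self.freeable.popitem(last=False)
            del self.cached[evicted_hash]
            self.freed.append(evicted_hash)
            self.num_free_slots += evicted_tokens
        self.num_free_slots -= num_tokens
        self.num_freeable_slots -= num_tokens
        self.cached.setdefault(mm_hash, set()).add(req_id)
        return True

    def free(self, mm_hash: str, req_id: str, num_tokens: int) -> None:
        refs = self.cached.get(mm_hash)
        if not refs:
            return
        refs.discard(req_id)
        if not refs:
            # Unreferenced: reclaimable, but still resident until evicted.
            self.freeable[mm_hash] = num_tokens
            self.num_freeable_slots += num_tokens
```

With a 10-token cache, allocating a 6-token entry, dropping its last reference, and then allocating a different 6-token entry evicts the first and reports its hash via `freed` — mirroring how the scheduler forwards evicted hashes to the workers.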
@ -177,16 +245,11 @@ def compute_encoder_budget(
"""Compute the encoder cache budget based on the model and scheduler
configurations.
Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
in the input sequence.
- Space budget for encoder cache size, in unit of number of tokens
in the input sequence.
- Compute budget for encoder execution, measured in number of tokens
from the input sequence.
- Space budget for encoder cache size, measured in number of tokens
from the input sequence.
"""
if mm_registry.supports_multimodal_inputs(model_config):
max_tokens_by_modality = mm_registry \
@ -231,10 +294,10 @@ def compute_mm_encoder_budget(
non-text modality.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
in the input sequence.
- Space budget for encoder cache size, in unit of number of tokens
in the input sequence.
- Compute budget for encoder execution, measured in number of tokens
from the input sequence.
- Space budget for encoder cache size, measured in number of tokens
from the input sequence.
"""
if not max_tokens_by_modality:

View File

@ -143,9 +143,9 @@ class SchedulerOutput:
# steps. This is used to notify the workers about the finished requests
# so that they can free the cached states for those requests.
finished_req_ids: set[str]
# list of (req_id, encoder_input_index) tuples.
# Used to free the encoder cache.
free_encoder_input_ids: list[tuple[str, int]]
# list of mm_hash strings associated with the encoder outputs to be
# freed from the encoder cache.
free_encoder_mm_hashes: list[str]
# Dict of request ids to their index within the batch
# for filling the next token bitmask

View File

@ -252,6 +252,7 @@ class Scheduler(SchedulerInterface):
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
if self.log_stats:
@ -550,7 +551,8 @@ class Scheduler(SchedulerInterface):
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
free_encoder_mm_hashes=self.encoder_cache_manager.
get_freed_mm_hashes(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
)
@ -698,7 +700,7 @@ class Scheduler(SchedulerInterface):
# in the decoder's KV cache.
continue
if self.encoder_cache_manager.has_cache(request, i):
if self.encoder_cache_manager.check_and_update_cache(request, i):
# The encoder input is already computed and cached.
continue
@ -712,8 +714,8 @@ class Scheduler(SchedulerInterface):
num_new_tokens = start_pos - num_computed_tokens
break
if (not self.encoder_cache_manager.can_allocate(request, i)
or num_encoder_tokens > encoder_budget):
if not self.encoder_cache_manager.try_allocate(
request, i, encoder_budget):
# The encoder cache is full or the encoder budget is exhausted.
# NOTE(woosuk): We assume that the encoder input tokens should
# be processed altogether, as the encoder usually uses

View File

@ -21,6 +21,8 @@ from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.backend_lm_format_enforcer import (
validate_structured_output_request_lm_format_enforcer)
from vllm.v1.structured_output.backend_outlines import (
validate_structured_output_request_outlines)
from vllm.v1.structured_output.backend_xgrammar import (
@ -200,6 +202,9 @@ class Processor:
elif engine_level_backend == "outlines":
# outlines backend
validate_structured_output_request_outlines(params)
elif engine_level_backend == "lm-format-enforcer":
# lm format enforcer backend
validate_structured_output_request_lm_format_enforcer(params)
else:
# NOTE: engine_level_backend must be "auto" here, because we have
# checked supported_backends above.

View File

@ -108,6 +108,14 @@ class StructuredOutputManager:
tokenizer=self.tokenizer,
vocab_size=vocab_size,
)
elif backend == "lm-format-enforcer":
from vllm.v1.structured_output.backend_lm_format_enforcer import ( # noqa: E501
LMFormatEnforcerBackend)
self.backend = LMFormatEnforcerBackend(
self.vllm_config,
tokenizer=self.tokenizer,
vocab_size=vocab_size,
)
else:
raise ValueError(
f"Unsupported structured output backend: {backend}")

View File

@ -0,0 +1,167 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import ast
import json
from dataclasses import dataclass, field
from functools import lru_cache
from typing import TYPE_CHECKING
import torch
from transformers import PreTrainedTokenizerBase
from vllm.sampling_params import SamplingParams
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar,
StructuredOutputOptions)
if TYPE_CHECKING:
import lmformatenforcer
import lmformatenforcer.integrations.vllm as lmfe_vllm
else:
lmformatenforcer = LazyLoader("lmformatenforcer", globals(),
"lmformatenforcer")
lmfe_vllm = LazyLoader("lmformatenforcer.integrations.vllm", globals(),
"lmformatenforcer.integrations.vllm")
@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer: PreTrainedTokenizerBase,
vocab_size: int) -> lmfe_vllm.TokenEnforcerTokenizerData:
return lmfe_vllm.build_vllm_token_enforcer_tokenizer_data(
tokenizer, use_bitmask=True, vocab_size=vocab_size)
@dataclass
class LMFormatEnforcerGrammar(StructuredOutputGrammar):
token_enforcer: lmformatenforcer.TokenEnforcer
current_tokens_prefix: list[int] = field(default_factory=list)
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
original_len = len(self.current_tokens_prefix)
for token in tokens:
if not self.token_enforcer.get_allowed_tokens(
self.current_tokens_prefix).is_token_allowed(token):
# Rollback partial updates to ensure atomicity.
del self.current_tokens_prefix[original_len:]
return False
self.current_tokens_prefix.append(token)
return True
def validate_tokens(self, tokens: list[int]) -> list[int]:
for prefix_length in range(len(tokens)):
prefix = tokens[:prefix_length]
next_token = tokens[prefix_length]
if not self.token_enforcer.get_allowed_tokens(
self.current_tokens_prefix +
prefix).is_token_allowed(next_token):
break
else:
return tokens
return tokens[:prefix_length]
def rollback(self, num_tokens: int) -> None:
self.current_tokens_prefix = self.current_tokens_prefix[:-num_tokens]
def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
allowed_tokens = self.token_enforcer.get_allowed_tokens(
self.current_tokens_prefix)
bitmask[batch_index] = allowed_tokens.allowed_tokens
def is_terminated(self) -> bool:
# We are considered terminated if the prefix ends with eos_token_id
return_value = len(
self.current_tokens_prefix) > 0 and self.current_tokens_prefix[
-1] == self.token_enforcer.eos_token_id
return return_value
def reset(self):
self.current_tokens_prefix = []
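Two contracts are implied by the grammar methods above: `accept_tokens` is atomic (partial updates are rolled back on rejection), and `validate_tokens` returns the longest acceptable prefix. A toy stand-in that accepts only prefixes of one target sequence illustrates both (hypothetical `ToyPrefixGrammar`, not the lmformatenforcer API):

```python
class ToyPrefixGrammar:
    """Toy grammar: a token sequence is valid iff it is a prefix of
    one fixed target sequence."""

    def __init__(self, target: list[int]):
        self.target = target
        self.prefix: list[int] = []

    def _allowed(self, prefix: list[int], token: int) -> bool:
        return (len(prefix) < len(self.target)
                and self.target[len(prefix)] == token)

    def accept_tokens(self, tokens: list[int]) -> bool:
        original_len = len(self.prefix)
        for token in tokens:
            if not self._allowed(self.prefix, token):
                # Roll back partial updates to ensure atomicity.
                del self.prefix[original_len:]
                return False
            self.prefix.append(token)
        return True

    def validate_tokens(self, tokens: list[int]) -> list[int]:
        # Return the longest prefix of `tokens` valid after self.prefix.
        for i, token in enumerate(tokens):
            if not self._allowed(self.prefix + tokens[:i], token):
                return tokens[:i]
        return tokens
```

Rejected batches leave the accepted prefix untouched, which is what lets the scheduler retry with a shorter (validated) prefix.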
@dataclass
class LMFormatEnforcerBackend(StructuredOutputBackend):
def __post_init__(self):
self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
self.tokenizer, self.vocab_size)
def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar:
character_level_parser: lmformatenforcer.CharacterLevelParser
if request_type == StructuredOutputOptions.JSON:
spec_dict = json.loads(grammar_spec)
character_level_parser = lmformatenforcer.JsonSchemaParser(
spec_dict)
elif request_type == StructuredOutputOptions.JSON_OBJECT:
character_level_parser = lmformatenforcer.JsonSchemaParser(None)
elif request_type == StructuredOutputOptions.REGEX:
character_level_parser = lmformatenforcer.RegexParser(grammar_spec)
elif request_type == StructuredOutputOptions.CHOICE:
choices = ast.literal_eval(grammar_spec)
character_level_parser = lmformatenforcer.UnionParser(
[lmformatenforcer.StringParser(choice) for choice in choices])
else:
raise ValueError(
"Invalid request type for LM Format Enforcer backend"
f"({request_type!s})")
max_rollback_tokens = (
self.vllm_config.speculative_config.num_speculative_tokens
if self.vllm_config.speculative_config is not None else 0)
if max_rollback_tokens > 0:
raise ValueError(
"LM Format Enforcer backend does not support speculative tokens"
)
token_enforcer = lmformatenforcer.TokenEnforcer(
tokenizer_data=self.tokenizer_data,
parser=character_level_parser,
)
return LMFormatEnforcerGrammar(token_enforcer)
def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor:
return torch.full(
(max_num_seqs, (self.vocab_size + 31) // 32),
-1,
dtype=torch.int32,
pin_memory=torch.cuda.is_available(),
)
def destroy(self):
pass
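`allocate_token_bitmask` reserves one bit per vocabulary token, packed 32 to an `int32` word; filling with `-1` (all bits set) means every token is allowed. A pure-Python sketch of that packing, including the wrap into the signed `int32` range a `torch.int32` tensor would store (helper names here are illustrative, not library API):

```python
def make_allowed_bitmask(vocab_size: int, allowed: list[int]) -> list[int]:
    """Pack allowed token ids into 32-bit words: bit (i % 32) of word
    (i // 32) marks whether token i may be sampled."""
    words = [0] * ((vocab_size + 31) // 32)
    for token_id in allowed:
        word, bit = divmod(token_id, 32)
        words[word] |= 1 << bit
    # Wrap into the signed int32 range, as an int32 tensor stores it.
    return [w - (1 << 32) if w >= (1 << 31) else w for w in words]


def is_allowed(words: list[int], token_id: int) -> bool:
    # Python's arithmetic shift on negative ints matches two's complement,
    # so bit extraction works unchanged on the wrapped values.
    word, bit = divmod(token_id, 32)
    return (words[word] >> bit) & 1 == 1
```

A word with bit 31 set becomes negative after the wrap, which is why an all-allowed mask can simply be a `-1` fill.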
def validate_structured_output_request_lm_format_enforcer(
params: SamplingParams):
if params.guided_decoding is None:
return
gd_params = params.guided_decoding
if gd_params.regex:
return
elif gd_params.json:
if isinstance(gd_params.json, str):
try:
# make sure schema is valid json
json.loads(gd_params.json)
except json.JSONDecodeError as e:
raise ValueError("Invalid JSON grammar specification.") from e
else:
try:
json.dumps(gd_params.json)
except Exception as e:
raise ValueError(
f"Error serializing guided decoding jsonschema: {e}"
) from e
return
elif gd_params.choice:
return
elif gd_params.grammar:
raise ValueError("LM Format Enforcer guided decoding backend "
"does not support grammar specifications")

View File

@ -35,7 +35,8 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import (BatchDescriptor, DPMetadata,
set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
@ -54,7 +55,6 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
get_dtype_size, is_pin_memory_available, round_up,
supports_dynamo)
from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
from vllm.v1.attention.backends.utils import (
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
make_kv_sharing_fast_prefill_attention_metadata)
@ -83,8 +83,9 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorModelRunnerMixin, KVConnectorOutput)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache,
gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
from .utils import (AttentionGroup, CpuGpuBuffer, MultiModalBudget,
bind_kv_cache, gather_mm_placeholders,
initialize_kv_cache_for_kv_sharing,
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
if TYPE_CHECKING:
@ -149,6 +150,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
parallel_config)
self.hidden_size = model_config.get_hidden_size()
self.attention_chunk_size = model_config.attention_chunk_size
# Only relevant for models using ALiBi (e.g., MPT)
self.use_alibi = check_use_alibi(model_config)
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
@ -176,8 +179,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.attn_groups: list[list[AttentionGroup]] = []
# self.kv_cache_config: KVCacheConfig
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# mm_hash -> encoder_output
self.encoder_cache: dict[str, torch.Tensor] = {}
self.use_aux_hidden_state_outputs = False
# Set up speculative decoding.
@ -226,21 +229,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self._init_device_properties()
# Persistent buffers for CUDA graphs.
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
self.query_start_loc = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device=self.device)
self.seq_lens = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device=self.device)
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: Optional[IntermediateTensors] = None
self.input_ids = self._make_buffer(self.max_num_tokens,
dtype=torch.int32)
self.positions = self._make_buffer(self.max_num_tokens,
dtype=torch.int64)
self.query_start_loc = self._make_buffer(self.max_num_reqs + 1,
dtype=torch.int32)
self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=self.device)
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
if self.uses_mrope:
@ -254,23 +253,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1),
dtype=torch.int64,
device=self.device)
self.mrope_positions_cpu = torch.zeros(
(3, self.max_num_tokens + 1),
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.mrope_positions_np = self.mrope_positions_cpu.numpy()
self.mrope_positions = self._make_buffer(
(3, self.max_num_tokens + 1), dtype=torch.int64)
# Only relevant for models using ALiBi (e.g., MPT)
self.use_alibi = check_use_alibi(model_config)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=self.device)
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: Optional[IntermediateTensors] = None
self.block_tables = BlockTables(
block_sizes=[self.cache_config.block_size],
@ -289,30 +276,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.max_num_tokens),
dtype=np.int64)
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
# a faster version of creating a new tensor every time. Thus, we should
# not make any assumptions about the values in these tensors.
self.input_ids_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.input_ids_np = self.input_ids_cpu.numpy()
self.positions_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.positions_np = self.positions_cpu.numpy()
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
self.index_mapping_cpu = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device="cpu",
@ -353,6 +316,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self._draft_token_ids: Optional[Union[list[list[int]],
torch.Tensor]] = None
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(*args,
dtype=dtype,
device=self.device,
pin_memory=self.pin_memory)
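`_make_buffer` presumes a small paired-buffer helper. The following is a plausible sketch of such a `CpuGpuBuffer` (the actual class lives in `.utils`; this shape is inferred from the call sites `*.np`, `*.cpu`, `*.gpu`, and `copy_to_gpu(n)` in this diff, and is an assumption, not the real implementation):

```python
from typing import Optional

import torch


class CpuGpuBuffer:
    """Paired staging buffer: an optionally pinned CPU tensor, a numpy
    view of it for cheap host-side writes, and a device-side mirror
    that async host-to-device copies target."""

    def __init__(self, *size, dtype: torch.dtype, device: torch.device,
                 pin_memory: bool):
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu",
                               pin_memory=pin_memory)
        self.np = self.cpu.numpy()
        self.gpu = torch.zeros_like(self.cpu, device=device)

    def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
        # Copy only the first n rows when given; non_blocking lets the
        # H2D transfer overlap host work when the CPU side is pinned.
        if n is None:
            return self.gpu.copy_(self.cpu, non_blocking=True)
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]
```

Host code writes through `buf.np`, then a single `copy_to_gpu(total_num_scheduled_tokens)` replaces the separate `*_cpu`/`*_np`/device tensor triples the diff removes.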
def _init_model_kwargs(self, num_tokens: int):
return {}
model_kwargs = dict[str, Any]()
@ -378,7 +347,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if len(token_type_id_requests) == 0:
return model_kwargs
seq_lens = self.seq_lens[:num_reqs]
seq_lens = self.seq_lens.gpu[:num_reqs]
token_type_ids = []
for i in range(num_reqs):
@ -422,12 +391,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.encoder_cache.pop(req_id, None)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
req_indices: list[int] = []
cu_num_new_blocks: list[list[int]] = [
@ -617,10 +582,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
token_ids=self.requests.token_ids.np,
num_computed_tokens=self.requests.num_computed_tokens.np,
num_scheduled_tokens=num_scheduled_tokens,
input_ids=self.input_ids_np,
query_start_loc=self.query_start_loc_np,
seq_lens=self.seq_lens_np,
positions=self.positions_np,
input_ids=self.input_ids.np,
query_start_loc=self.query_start_loc.np,
seq_lens=self.seq_lens.np,
positions=self.positions.np,
)
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
@ -628,24 +593,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self._calc_mrope_positions(scheduler_output)
# Prepare the attention metadata.
self.query_start_loc.copy_(self.query_start_loc_cpu, non_blocking=True)
query_start_loc = self.query_start_loc[:num_reqs + 1]
self.query_start_loc.copy_to_gpu()
query_start_loc = self.query_start_loc.gpu[:num_reqs + 1]
self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)
seq_lens = self.seq_lens[:num_reqs]
max_seq_len = self.seq_lens_np[:num_reqs].max().item()
self.seq_lens.copy_to_gpu()
seq_lens = self.seq_lens.gpu[:num_reqs]
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
# Copy the tensors to the GPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
# Common case (1D positions)
self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
self.mrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True)
else:
# Common case (1D positions)
self.positions.copy_to_gpu(total_num_scheduled_tokens)
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
@ -698,8 +662,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
query_start_loc, self.positions[:total_num_scheduled_tokens])
# Used in the below loop.
query_start_loc_cpu = self.query_start_loc_cpu[:num_reqs + 1]
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
num_computed_tokens_cpu = (
self.requests.num_computed_tokens.cpu[:num_reqs])
spec_decode_common_attn_metadata = None
@ -936,9 +900,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
src_start = num_computed_tokens
src_end = num_computed_tokens + prompt_part_len
self.mrope_positions_cpu[:, dst_start:dst_end] = \
req.mrope_positions[:,src_start:src_end]
self.mrope_positions.cpu[:, dst_start:dst_end] = (
req.mrope_positions[:, src_start:src_end])
mrope_pos_ptr += prompt_part_len
if completion_part_len > 0:
@ -947,7 +910,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dst_end = mrope_pos_ptr + completion_part_len
MRotaryEmbedding.get_next_input_positions_tensor(
out=self.mrope_positions_np,
out=self.mrope_positions.np,
out_offset=dst_start,
mrope_position_delta=req.mrope_position_delta,
context_len=num_computed_tokens + prompt_part_len,
@ -1011,7 +974,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
draft_token_ids = self.input_ids[logits_indices]
draft_token_ids = self.input_ids.gpu[logits_indices]
draft_token_ids = draft_token_ids[target_logits_indices + 1]
metadata = SpecDecodeMetadata(
@ -1028,17 +991,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
# Batch the multi-modal inputs.
mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
# list of (mm_hash, position_info) tuples
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids:
mm_hash = req_state.mm_hashes[mm_input_id]
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
mm_hashes_pos.append(
(mm_hash, req_state.mm_positions[mm_input_id]))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
@ -1071,15 +1035,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for output in curr_group_outputs:
encoder_outputs.append(output)
# Cache the encoder outputs.
for (req_id, input_id, pos_info), output in zip(
req_ids_pos,
encoder_outputs,
):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
# Cache the encoder outputs by mm_hash
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
self.encoder_cache[mm_hash] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed,
)
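The hunk above replaces the per-request nested cache (`encoder_cache[req_id][input_id]`) with a flat cache keyed by the multimodal item's content hash, so identical items shared across requests resolve to one entry. A minimal sketch of that pattern, with a dummy stand-in encoder (the helper names here are assumptions, not vLLM APIs):

```python
# Sketch: cache encoder outputs by content hash so two requests that
# embed the same image reuse a single cached output.
import hashlib

encoder_cache: dict[str, list[float]] = {}

def encode(image_bytes: bytes) -> list[float]:
    # Dummy "encoder" for illustration only.
    return [float(b) for b in image_bytes[:4]]

def get_encoder_output(image_bytes: bytes) -> list[float]:
    mm_hash = hashlib.sha256(image_bytes).hexdigest()
    if mm_hash not in encoder_cache:
        encoder_cache[mm_hash] = encode(image_bytes)
    return encoder_cache[mm_hash]

# Two requests with the same image share one cache entry (same object).
a = get_encoder_output(b"same-image")
b = get_encoder_output(b"same-image")
assert a is b and len(encoder_cache) == 1
```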
@@ -1098,6 +1056,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
shift_computed_tokens)
req_data = self.requests.req_data[req_idx]
mm_positions = req_data.mm_positions
mm_hashes = req_data.mm_hashes
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
@@ -1117,11 +1076,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens)
num_encoder_tokens,
)
assert start_idx < end_idx
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
mm_hash = mm_hashes[i]
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
f"Encoder cache miss for {mm_hash}."
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
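The `start_idx`/`end_idx` arithmetic in this hunk selects the slice of a cached encoder output that overlaps the tokens scheduled in the current step. A worked example with assumed numbers:

```python
# An mm item occupies positions [start_pos, start_pos + num_encoder_tokens);
# this step schedules [num_computed_tokens,
#                      num_computed_tokens + num_scheduled_tokens).
start_pos = 100            # where the mm item's placeholder begins
num_encoder_tokens = 50    # length of the mm item
num_computed_tokens = 120  # prefill progress before this step
num_scheduled_tokens = 16  # tokens scheduled this step

start_idx = max(num_computed_tokens - start_pos, 0)   # 20
end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens,
              num_encoder_tokens)                     # 36
assert start_idx < end_idx
# This step consumes encoder_output[20:36] of the cached item.
```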
@@ -1343,7 +1305,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pooling_metadata = self.input_batch.pooling_metadata
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
device=hidden_states.device)
seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]
seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]
# Pooling models D2H & synchronize occurs in pooler.py:build_output
raw_pooler_output = self.model.pooler(
@@ -1420,7 +1382,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
inputs_embeds_scheduled = self.model.get_input_embeddings(
input_ids=self.input_ids[:num_scheduled_tokens],
input_ids=self.input_ids.gpu[:num_scheduled_tokens],
multimodal_embeddings=mm_embeds or None,
)
@@ -1439,13 +1401,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens]
input_ids = self.input_ids.gpu[:num_input_tokens]
inputs_embeds = None
model_kwargs = self._init_model_kwargs(num_input_tokens)
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
positions = self.mrope_positions.gpu[:, :num_input_tokens]
else:
positions = self.positions[:num_input_tokens]
positions = self.positions.gpu[:num_input_tokens]
if get_pp_group().is_first_rank:
intermediate_tensors = None
@@ -1694,9 +1656,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if input_batch.spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
# TODO(woosuk): Support M-RoPE.
target_positions = self.positions[:num_scheduled_tokens]
target_positions = self.positions.gpu[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
@@ -1716,9 +1678,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.drafter.prepare_inputs(
common_attn_metadata, num_rejected_tokens_cpu)
target_token_ids = self.input_ids[token_indices]
target_token_ids = self.input_ids.gpu[token_indices]
# TODO(woosuk): Support M-RoPE.
target_positions = self.positions[token_indices]
target_positions = self.positions.gpu[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
@@ -1959,8 +1921,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Get the logits corresponding to this req's prompt tokens.
# If this is a partial request (i.e. chunked prefill),
# then there is prompt logprob generated for each index.
req_idx = 0
offset = self.query_start_loc_np[req_idx].item()
req_idx = self.input_batch.req_id_to_index[req_id]
offset = self.query_start_loc.np[req_idx].item()
prompt_hidden_states = hidden_states[offset:offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states, None)
@@ -2032,7 +1994,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
@functools.cache
def rand_input_ids() -> torch.Tensor:
return torch.randint_like(
self.input_ids,
self.input_ids.gpu,
low=0,
high=self.model_config.get_vocab_size(),
dtype=input_ids.dtype)
@@ -2149,18 +2111,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata = {}
# Make sure max_model_len is used at the graph capture time.
self.seq_lens_np[:num_reqs] = self.max_model_len
self.seq_lens_np[num_reqs:] = 0
self.seq_lens.copy_(self.seq_lens_cpu, non_blocking=True)
self.seq_lens.np[:num_reqs] = self.max_model_len
self.seq_lens.np[num_reqs:] = 0
self.seq_lens.copy_to_gpu()
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs +
query_start_loc=self.query_start_loc.gpu[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc.cpu[:num_reqs +
1],
seq_lens=self.seq_lens[:num_reqs],
seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
seq_lens=self.seq_lens.gpu[:num_reqs],
seq_lens_cpu=self.seq_lens.cpu[:num_reqs],
num_computed_tokens_cpu=self.requests.num_computed_tokens.
cpu[:num_reqs],
num_reqs=num_reqs,
@@ -2190,14 +2152,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
**self._dummy_mm_kwargs(num_reqs),
}
else:
input_ids = self.input_ids[:num_tokens]
input_ids = self.input_ids.gpu[:num_tokens]
inputs_embeds = None
model_kwargs = self._init_model_kwargs(num_tokens)
if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens]
positions = self.mrope_positions.gpu[:, :num_tokens]
else:
positions = self.positions[:num_tokens]
positions = self.positions.gpu[:num_tokens]
if get_pp_group().is_first_rank:
intermediate_tensors = None
@@ -2583,11 +2545,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"""
assert len(self.attn_groups) == 0, \
"Attention backends are already initialized"
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
def get_attn_backends_for_layers(
layer_names: list[str]
) -> dict[type[AttentionBackend], list[str]]:
layers = get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase,
layer_names)
attn_backends = {}
attn_backend_layers = defaultdict(list)
# Dedupe based on full class name; this is a bit safer than using
@@ -2596,7 +2560,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# they are cached correctly, there will be different objects per
# layer.
for layer_name in layer_names:
attn_backend = attn_layers[layer_name].get_attn_backend()
attn_backend = layers[layer_name].get_attn_backend()
key = attn_backend.full_cls_name()
attn_backends[key] = attn_backend
attn_backend_layers[key].append(layer_name)
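The dedupe-by-full-class-name pattern in this hunk can be sketched on its own. This is a minimal illustration with hypothetical backend classes and layer names, not the vLLM types:

```python
# Sketch: dedupe attention backends by fully qualified class name while
# tracking which layers use each backend.
from collections import defaultdict

class FlashBackend: ...
class TritonBackend: ...

def full_cls_name(cls: type) -> str:
    # Mirrors the idea of a "full class name" key: module + qualname is
    # safer than the bare class name, which can collide across modules.
    return f"{cls.__module__}.{cls.__qualname__}"

layer_to_backend = {
    "layers.0.attn": FlashBackend,
    "layers.1.attn": FlashBackend,
    "layers.2.attn": TritonBackend,
}

attn_backends: dict[str, type] = {}
attn_backend_layers = defaultdict(list)
for layer_name, backend in layer_to_backend.items():
    key = full_cls_name(backend)
    attn_backends[key] = backend
    attn_backend_layers[key].append(layer_name)

assert len(attn_backends) == 2
```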
@@ -2625,20 +2589,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if isinstance(kv_cache_spec, AttentionSpec):
attn_backends = get_attn_backends_for_layers(
kv_cache_group_spec.layer_names)
# TODO(lucas): move `get_mamba_attn_backend` into the mamba
# layers like above
elif isinstance(kv_cache_spec, MambaSpec):
attn_backends = {
get_mamba_attn_backend(kv_cache_spec.mamba_type):
kv_cache_group_spec.layer_names
}
else:
raise ValueError(
f"Unknown KV cache spec type: {type(kv_cache_spec)}")
attn_backends = get_attn_backends_for_layers(
kv_cache_group_spec.layer_names)
self.attn_groups.append(
create_attn_groups(attn_backends, kv_cache_spec))


@@ -208,8 +208,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Lazy initialization
self.model: nn.Module # Set after load_model
self.kv_caches: list[torch.Tensor] = []
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# mm_hash -> encoder_output
self.encoder_cache: dict[str, torch.Tensor] = {}
# Request states.
self.requests: dict[str, CachedRequestState] = {}
@@ -342,7 +342,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -357,12 +356,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
removed_req_indices.append(req_index)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
# Remove the unscheduled requests from the persistent batch.
# NOTE(woosuk): The unscheduled requests are either preempted requests
@@ -394,6 +389,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
prompt_token_ids=new_req_data.prompt_token_ids,
mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions,
mm_hashes=new_req_data.mm_hashes,
sampling_params=sampling_params,
pooling_params=None,
generator=None,
@@ -845,14 +841,16 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Batch the multi-modal inputs.
mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
# List of tuples (mm_hash, pos_info).
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids:
mm_hash = req_state.mm_hashes[mm_input_id]
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
mm_hashes_pos.append(
(mm_hash, req_state.mm_positions[mm_input_id]))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
@@ -895,15 +893,15 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# NOTE (NickLucche) Here we diverge from the logic in other runners, as
# we assume we only process whole mm items. Hence we avoid the
# intrinsic dynamism that `scatter_mm_placeholders` introduces.
for (req_id, input_id, pos_info), output in zip(
req_ids_pos,
for (mm_hash, pos_info), output in zip(
mm_hashes_pos,
encoder_outputs,
):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
assert pos_info.is_embed is None, "Expected all positions to be"\
" contiguous and embeddings."
self.encoder_cache[req_id][input_id] = output
self.encoder_cache[mm_hash] = output
def _gather_mm_embeddings(
self,
@@ -916,6 +914,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions
mm_hashes = req_state.mm_hashes
# TODO unroll loop and assume/enforce --disable_chunked_mm_input
# NOTE (NickLucche) Here we diverge from the logic in other runners, as
# we assume we only process whole mm items. Hence we avoid
@@ -936,11 +935,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# in the decoder's KV cache.
continue
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
mm_hash = mm_hashes[i]
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
f"Encoder cache miss for {mm_hash}."
assert pos_info.is_embed is None, "Expected all positions to"\
" be contiguous and embeddings."
encoder_output = self.encoder_cache[req_id][i]
encoder_output = self.encoder_cache[mm_hash]
mm_embeds.append(encoder_output)
return mm_embeds


@@ -319,11 +319,11 @@ class CpuGpuBuffer:
def copy_to_gpu(self, n: Optional[int] = None) -> None:
if n is None:
return self.gpu.copy_(self.cpu, non_blocking=True)
else:
return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
def copy_to_cpu(self, n: Optional[int] = None) -> None:
"""NOTE: Because this method is non-blocking, explicit synchronization
is needed to ensure the data is copied to CPU."""
if n is None:
return self.cpu.copy_(self.gpu, non_blocking=True)
else:
return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)
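The `.np` / `.cpu` / `.gpu` accesses throughout this commit suggest one buffer object exposing a NumPy view, a CPU tensor, and a device tensor. A minimal sketch of that pattern, with NumPy arrays standing in for torch tensors so the example runs without a GPU; in the real class `cpu` would be a pinned torch tensor, `gpu` a device tensor, and the copies non-blocking:

```python
# Sketch of a CpuGpuBuffer-like staging buffer (NumPy stand-in).
from typing import Optional
import numpy as np

class CpuGpuBufferSketch:
    def __init__(self, n: int):
        self.cpu = np.zeros(n, dtype=np.int64)
        # In torch this would be `cpu.numpy()`, which shares memory with
        # the CPU tensor; here they are literally the same array.
        self.np = self.cpu
        self.gpu = np.zeros(n, dtype=np.int64)  # stand-in for device memory

    def copy_to_gpu(self, n: Optional[int] = None) -> None:
        if n is None:
            self.gpu[:] = self.cpu
        else:
            self.gpu[:n] = self.cpu[:n]

    def copy_to_cpu(self, n: Optional[int] = None) -> None:
        # In the real class this is non-blocking, so callers must
        # synchronize before reading the CPU side.
        if n is None:
            self.cpu[:] = self.gpu
        else:
            self.cpu[:n] = self.gpu[:n]

buf = CpuGpuBufferSketch(4)
buf.np[:2] = [7, 9]   # write through the NumPy view...
buf.copy_to_gpu(2)    # ...then stage only the first 2 slots
assert buf.gpu.tolist() == [7, 9, 0, 0]
```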
return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)