[Core][MM] Optimize encoder cache manager by operating with embeddings only (#30475)

Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Sun Kim <sunytokki@gmail.com>
This commit is contained in:
Roger Wang 2025-12-16 14:18:17 -08:00 committed by GitHub
parent 9fec0e13d5
commit f5f51e5931
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 306 additions and 130 deletions

View File

@ -60,12 +60,12 @@ def test_profiling(model_id: str, max_model_len: int):
total_num_patches.item() + num_tiles.item() + 3
) # image start, image, image end
profiled_tokens = profiler.get_mm_max_contiguous_tokens(
profiled_tokens = profiler.get_mm_max_tokens(
max_model_len,
mm_counts=mm_counts,
)
assert total_tokens == profiled_tokens["image"]
assert total_num_patches == profiled_tokens["image"]
assert total_tokens == sum(
placeholder.length
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]

View File

@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile, TemporaryDirectory
import numpy as np
import pytest
import torch
from PIL import Image, ImageChops
from vllm.multimodal.image import convert_image_mode
@ -410,6 +411,97 @@ def test_argsort_mm_positions(case):
assert modality_idxs == expected_modality_idxs
@pytest.mark.parametrize(
    "is_embed,expected",
    [
        (None, 5),
        (torch.tensor([True, True, True, True, True]), 5),
        (torch.tensor([False, False, False, False, False]), 0),
        (torch.tensor([True, False, True, False, True]), 3),
        (torch.tensor([True]), 1),
    ],
)
def test_placeholder_range_get_num_embeds(is_embed, expected):
    """`get_num_embeds` counts True entries of `is_embed`; with no mask it
    falls back to the placeholder length."""
    num_positions = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=0, length=num_positions, is_embed=is_embed
    )
    assert placeholder.get_num_embeds == expected
@pytest.mark.parametrize(
    "is_embed,expected",
    [
        (None, None),
        (
            torch.tensor([False, True, False, True, True]),
            torch.tensor([0, 1, 1, 2, 3]),
        ),
        (torch.tensor([True, True, True]), torch.tensor([1, 2, 3])),
    ],
)
def test_placeholder_range_embeds_cumsum(is_embed, expected):
    """`embeds_cumsum` is the running count of embeddings over the mask, and
    is None when no mask is provided."""
    num_positions = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=0, length=num_positions, is_embed=is_embed
    )
    if expected is None:
        assert placeholder.embeds_cumsum is None
    else:
        assert torch.equal(placeholder.embeds_cumsum, expected)
        # cached_property must hand back the identical object on each access.
        assert placeholder.embeds_cumsum is placeholder.embeds_cumsum
@pytest.mark.parametrize(
    "is_embed,start_idx,end_idx,expected",
    [
        (None, 2, 4, (2, 4)),
        (
            torch.tensor([False, True, False, True, True]),
            3,
            5,
            (1, 3),
        ),
        (
            torch.tensor([False, True, False, True, True]),
            0,
            2,
            (0, 1),
        ),
        (
            torch.tensor([True, False, True, False]),
            2,
            2,
            (1, 1),
        ),
    ],
)
def test_placeholder_range_get_embeds_indices_in_range(
    is_embed, start_idx, end_idx, expected
):
    """Mapping a [start, end) token window to embedding indices respects the
    mask; without a mask the window is returned unchanged."""
    num_positions = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=0, length=num_positions, is_embed=is_embed
    )
    result = placeholder.get_embeds_indices_in_range(start_idx, end_idx)
    assert result == expected
@pytest.mark.parametrize(
    "offset,is_embed,expected",
    [
        (0, None, [(0, 4)]),
        (
            2,
            torch.tensor([False, True, False, True, True]),
            [(3, 3), (5, 6)],
        ),
        (0, torch.tensor([True, True, True, True]), [(0, 3)]),
        (0, torch.tensor([False, False, False, False]), []),
    ],
)
def test_placeholder_range_extract_embeds_range(offset, is_embed, expected):
    """`extract_embeds_range` yields absolute (start, end) spans of embedded
    positions in the prompt, offset by the placeholder's own offset."""
    num_positions = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(
        offset=offset, length=num_positions, is_embed=is_embed
    )
    assert placeholder.extract_embeds_range() == expected
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
@ -23,7 +24,7 @@ class MockRequest:
)
self.mm_features.append(feature)
def get_num_encoder_tokens(self, input_id: int) -> int:
def get_num_encoder_embeds(self, input_id: int) -> int:
return self._token_counts[input_id]
@ -162,8 +163,8 @@ def test_schedule_request_multi_images_respect_space_limit():
num_tokens_to_schedule = 0
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
compute_budget -= req.get_num_encoder_tokens(0)
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
compute_budget -= req.get_num_encoder_embeds(0)
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
@ -174,7 +175,75 @@ def test_schedule_request_multi_images_respect_compute_limit():
compute_budget = 10
num_tokens_to_schedule = 0
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
compute_budget -= req.get_num_encoder_tokens(0)
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
compute_budget -= req.get_num_encoder_embeds(0)
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
def test_encoder_cache_with_is_embed_mask():
    """Cache accounting charges only masked-in embeddings, not the full
    placeholder length."""

    class MockRequestWithMask(MockRequest):
        def get_num_encoder_embeds(self, input_id: int) -> int:
            return self.mm_features[input_id].mm_position.get_num_embeds

    # 100 placeholder positions, of which exactly 8 carry embeddings.
    embed_positions = (5, 15, 25, 35, 45, 55, 65, 75)
    is_embed = torch.zeros(100, dtype=torch.bool)
    for pos in embed_positions:
        is_embed[pos] = True

    request = MockRequestWithMask("r1", ["img1"], [100])
    request.mm_features[0] = MultiModalFeatureSpec(
        data=None,
        modality="image",
        identifier="img1",
        mm_position=PlaceholderRange(offset=0, length=100, is_embed=is_embed),
    )

    cache_manager = EncoderCacheManager(cache_size=100)
    cache_manager.allocate(request, 0)

    # Only the 8 real embeddings occupy cache slots.
    assert cache_manager.num_free_slots == 92
    assert "img1" in cache_manager.cached

    placeholder_length = 100
    embed_count = request.mm_features[0].mm_position.get_num_embeds
    assert embed_count == 8
    assert placeholder_length / embed_count == 12.5
def test_encoder_cache_mask_based_retrieval():
    """Embedding counts derived from `is_embed` slices line up with the
    cache manager's allocation bookkeeping."""

    class MockRequestWithMask(MockRequest):
        def get_num_encoder_embeds(self, input_id: int) -> int:
            return self.mm_features[input_id].mm_position.get_num_embeds

    is_embed = torch.tensor(
        [False, False, True, True, False, True, True, True, False, False]
    )

    request = MockRequestWithMask("r1", ["img1"], [10])
    request.mm_features[0] = MultiModalFeatureSpec(
        data=None,
        modality="image",
        identifier="img1",
        mm_position=PlaceholderRange(offset=0, length=10, is_embed=is_embed),
    )

    cache_manager = EncoderCacheManager(cache_size=50)
    cache_manager.allocate(request, 0)
    assert request.mm_features[0].mm_position.get_num_embeds == 5

    # Window [2, 8) covers all five embedded positions.
    lo, hi = 2, 8
    assert is_embed[:lo].sum().item() == 0
    assert is_embed[lo:hi].sum().item() == 5

    # Window [0, 5) covers only the first two embedded positions.
    lo, hi = 0, 5
    preceding = is_embed[:lo].sum().item() if lo > 0 else 0
    assert preceding == 0
    assert is_embed[lo:hi].sum().item() == 2

View File

@ -38,7 +38,7 @@ class MockRequest:
)
self.mm_features.append(feature)
def get_num_encoder_tokens(self, input_id: int) -> int:
def get_num_encoder_embeds(self, input_id: int) -> int:
assert input_id < len(self._token_counts)
return self._token_counts[input_id]

View File

@ -144,7 +144,7 @@ class ECExampleConnector(ECConnectorBase):
Update ECConnector state after encoder cache allocation.
"""
mm_hash = request.mm_features[index].identifier
num_encoder_token = request.get_num_encoder_tokens(index)
num_encoder_token = request.get_num_encoder_embeds(index)
# Insert mm_hash only if this block has not been recorded yet.
self._mm_datas_need_loads[mm_hash] = num_encoder_token

View File

@ -713,17 +713,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
mm_counts: Mapping[str, int],
) -> int:
target_width, target_height = self.get_image_size_with_most_features()
video_soft_tokens = self.get_num_video_tokens(
num_video_soft_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
image_processor=None,
)
# NOTE: By default in Qwen3-VL, one video token is converted to
# "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501
formatted_video_soft_tokens = video_soft_tokens * 12.5
return int(formatted_video_soft_tokens)
return num_video_soft_tokens
def _calculate_timestamps(
self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from functools import partial
from functools import cached_property, partial
from itertools import accumulate
from typing import (
TYPE_CHECKING,
@ -169,11 +169,42 @@ class PlaceholderRange:
between `offset` and `offset + length` to assign embeddings to.
"""
def get_num_embeds(self) -> int:
@cached_property
def embeds_cumsum(self) -> torch.Tensor | None:
if self.is_embed is None:
return None
return self.is_embed.cumsum(dim=0)
@cached_property
def get_num_embeds(self) -> int:
if self.embeds_cumsum is None:
return self.length
return int(self.is_embed.sum().item())
return int(self.embeds_cumsum[-1])
def get_embeds_indices_in_range(
self, start_idx: int, end_idx: int
) -> tuple[int, int]:
"""
Returns the starting and ending indices of the embeddings of encoder outputs
in the range of [start_idx, end_idx) in the placeholders.
For example, given:
PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])
If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
the second and the third embeddings from the encoder output.
"""
if self.embeds_cumsum is None:
return start_idx, end_idx
embeds_start_idx = (
int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
)
embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])
return embeds_start_idx, embeds_end_idx
def extract_embeds_range(self) -> list[tuple[int, int]]:
"""Extract the start and end indices of the embedded region in prompt.
@ -188,7 +219,7 @@ class PlaceholderRange:
Returns full placeholder range if `is_embed` is `None`.
"""
if self.is_embed is None:
return [(self.offset, self.offset + self.length)]
return [(self.offset, self.offset + self.length - 1)]
mask_i = self.is_embed.int()
starts = torch.nonzero(

View File

@ -274,15 +274,11 @@ class MultiModalProfiler(Generic[_I]):
def _get_mm_num_tokens(
self,
mm_inputs: MultiModalInputs,
mm_embeddings_only: bool = True,
) -> Mapping[str, int]:
placeholders_by_modality = mm_inputs["mm_placeholders"]
return {
modality: sum(
item.get_num_embeds() if mm_embeddings_only else item.length
for item in placeholders
)
modality: sum(item.get_num_embeds for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}
@ -328,12 +324,15 @@ class MultiModalProfiler(Generic[_I]):
multi_modal_placeholders=mm_inputs["mm_placeholders"],
)
def _get_mm_max_tokens(
def get_mm_max_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
mm_embeddings_only: bool = True,
) -> Mapping[str, int]:
"""
Returns the maximum number of embeddings per item of each modality, excluding
any break/text tokens in-between multimodal embeddings/encoder outputs.
"""
if mm_counts is None:
mm_counts = self.get_mm_limits()
@ -349,21 +348,4 @@ class MultiModalProfiler(Generic[_I]):
}
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
def get_mm_max_contiguous_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
) -> Mapping[str, int]:
"""
Returns the maximum length of the multimodal (image placeholders+text)
tokens, including any break/text tokens in-between image embeddings.
`<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
Returns 9, even when the number of image embeddings is 6.
This is important to take into account when profiling and
initializing the encoder cache size.
"""
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
return self._get_mm_num_tokens(mm_inputs)

View File

@ -164,7 +164,7 @@ class MultiModalRegistry:
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
)
return profiler.get_mm_max_contiguous_tokens(
return profiler.get_mm_max_tokens(
seq_len,
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
)

View File

@ -39,20 +39,26 @@ class EncoderCacheManager:
space for new embeddings.
Oldest cached embeddings with no request referenced will be first evicted.
NOTE: The EncoderCacheManager operates on the level of multimodal embeddings
instead of encoder tokens (i.e. all tokens that represent the multimodal data
in the input sequence). This means all break/text tokens in-between multimodal
embeddings are not considered with respect to the cache size and the number
of free slots.
Args:
cache_size: Limit the size of the cache, measured by the number of
tokens from the input sequence.
encoder embeddings from the input sequence.
Attributes:
cache_size: Total cache capacity in encoder tokens.
num_free_slots: Current available cache capacity in encoder tokens.
cache_size: Total cache capacity in encoder embeddings.
num_free_slots: Current available cache capacity in encoder embeddings.
num_freeable_slots: Capacity that can be immediately reclaimed by
evicting entries with zero references (in encoder tokens).
evicting entries with zero references (in encoder embeddings).
cached: Mapping from mm_hash to a set of request IDs that currently
reference the cached entry. If the set is empty, the entry exists
but is not referenced by any request and is eligible for
reclamation.
freeable: List of tuples (mm_hash, num_tokens) representing entries
freeable: List of tuples (mm_hash, num_encoder_embeds) representing entries
whose no current running request is needed and that can be freed to
make space when needed.
freed: List of mm_hash strings that were actually evicted since the
@ -67,7 +73,7 @@ class EncoderCacheManager:
# mm_hash of mm_data => ids of requests that reference the mm_data
self.cached: dict[str, set[str]] = {}
# mm_hash of mm_data => num_encoder_tokens of the mm_data
# mm_hash of mm_data => num_encoder_embeds of the mm_data
self.freeable: OrderedDict[str, int] = OrderedDict()
self.freed: list[str] = []
@ -93,8 +99,8 @@ class EncoderCacheManager:
# Cached but currently not referenced by any request
if not self.cached[mm_hash]:
num_tokens = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_tokens
num_encoder_embeds = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_encoder_embeds
self.cached[mm_hash].add(request.request_id)
return True
@ -104,7 +110,7 @@ class EncoderCacheManager:
request: Request,
input_id: int,
encoder_compute_budget: int,
num_tokens_to_schedule: int,
num_embeds_to_schedule: int,
) -> bool:
"""Check if there's sufficient cache space for a multimodal input.
If there is, return True and update EncoderCacheManager state.
@ -121,9 +127,9 @@ class EncoderCacheManager:
Args:
request: The request containing the multimodal input.
input_id: Index of the multimodal input within the request.
encoder_compute_budget: Number of encoder tokens allowed to be
encoder_compute_budget: Number of encoder embeddings allowed to be
computed when this method is invoked.
num_tokens_to_schedule: Number of tokens already scheduled to be
num_embeds_to_schedule: Number of encoder embeddings already scheduled to be
allocated with cache space when this method is invoked.
Returns:
@ -134,30 +140,30 @@ class EncoderCacheManager:
Note: This method does not allocate physical memory for the encoder
output but only the state of EncoderCacheManager.
"""
num_tokens = request.get_num_encoder_tokens(input_id)
num_embeds = request.get_num_encoder_embeds(input_id)
# Not enough compute budget
if num_tokens > encoder_compute_budget:
if num_embeds > encoder_compute_budget:
return False
num_tokens += num_tokens_to_schedule
num_embeds += num_embeds_to_schedule
# Enough free slots
if num_tokens <= self.num_free_slots:
if num_embeds <= self.num_free_slots:
return True
# Not enough reclaimable slots
if num_tokens > self.num_freeable_slots:
if num_embeds > self.num_freeable_slots:
return False
# Not enough free slots but enough reclaimable slots
# NOTE: Eviction takes place here, but physical memory is not freed
# until model runner is notified by the scheduler output.
while num_tokens > self.num_free_slots:
mm_hash, num_free_token = self.freeable.popitem(last=False)
while num_embeds > self.num_free_slots:
mm_hash, num_free_embeds = self.freeable.popitem(last=False)
del self.cached[mm_hash]
self.freed.append(mm_hash)
self.num_free_slots += num_free_token
self.num_free_slots += num_free_embeds
return True
def allocate(self, request: Request, input_id: int) -> None:
@ -176,16 +182,16 @@ class EncoderCacheManager:
if mm_hash not in self.cached:
self.cached[mm_hash] = set()
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
# NOTE: Encoder cache should always have enough space for encoder inputs
# that are scheduled since eviction takes place at can_allocate().
assert self.num_free_slots >= num_encoder_tokens
assert self.num_freeable_slots >= num_encoder_tokens
assert self.num_free_slots >= num_encoder_embeds
assert self.num_freeable_slots >= num_encoder_embeds
self.cached[mm_hash].add(request_id)
self.num_free_slots -= num_encoder_tokens
self.num_freeable_slots -= num_encoder_tokens
self.num_free_slots -= num_encoder_embeds
self.num_freeable_slots -= num_encoder_embeds
def get_cached_input_ids(self, request: Request) -> set[int]:
"""Get all cached multimodal input IDs for a request.
@ -206,7 +212,7 @@ class EncoderCacheManager:
When the reference set for the corresponding `mm_hash` becomes empty,
the entry is appended to `freeable` and `num_freeable_slots` is
increased by the number of encoder tokens for that input.
increased by the number of encoder embeddings for that input.
The entry is NOT physically freed until capacity is needed (e.g., by
`can_allocate`).
@ -218,9 +224,9 @@ class EncoderCacheManager:
return
self.cached[mm_hash].discard(req_id)
if not self.cached[mm_hash]:
num_tokens = request.get_num_encoder_tokens(input_id)
self.freeable[mm_hash] = num_tokens
self.num_freeable_slots += num_tokens
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.freeable[mm_hash] = num_encoder_embeds
self.num_freeable_slots += num_encoder_embeds
def free(self, request: Request) -> None:
"""Free all encoder input cache reference held by *request*.
@ -361,20 +367,20 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
request: Request,
input_id: int,
encoder_compute_budget: int,
num_tokens_to_schedule: int,
num_embeds_to_schedule: int,
) -> bool:
num_tokens = request.get_num_encoder_tokens(input_id)
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
# Not enough compute budget
if num_tokens > encoder_compute_budget:
if num_encoder_embeds > encoder_compute_budget:
return False
num_tokens += num_tokens_to_schedule
num_encoder_embeds += num_embeds_to_schedule
# Enough free slots
return num_tokens <= self.num_free_slots
return num_encoder_embeds <= self.num_free_slots
def allocate(self, request: Request, input_id: int) -> None:
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
self.num_free_slots -= num_encoder_tokens
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.num_free_slots -= num_encoder_embeds
mm_hash = request.mm_features[input_id].identifier
self.freed.append(mm_hash)
@ -392,5 +398,5 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
return freed
def free_encoder_input(self, request: Request, input_id: int) -> None:
num_tokens = request.get_num_encoder_tokens(input_id)
self.num_free_slots += num_tokens
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.num_free_slots += num_encoder_embeds

View File

@ -355,11 +355,11 @@ class Scheduler(SchedulerInterface):
if preempted_encoder_inputs:
# Restore encoder compute budget if the preempted
# request had encoder inputs scheduled in this step.
num_tokens_to_restore = sum(
preempted_req.get_num_encoder_tokens(i)
num_embeds_to_restore = sum(
preempted_req.get_num_encoder_embeds(i)
for i in preempted_encoder_inputs
)
encoder_compute_budget += num_tokens_to_restore
encoder_compute_budget += num_embeds_to_restore
req_index -= 1
else:
preempted_req = self.running.pop()
@ -911,10 +911,11 @@ class Scheduler(SchedulerInterface):
# multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level.
mm_hashes_to_schedule = set()
num_tokens_to_schedule = 0
num_embeds_to_schedule = 0
for i, mm_feature in enumerate(mm_features):
start_pos = mm_feature.mm_position.offset
num_encoder_tokens = mm_feature.mm_position.length
num_encoder_embeds = mm_feature.mm_position.get_num_embeds
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
@ -970,9 +971,8 @@ class Scheduler(SchedulerInterface):
):
num_new_tokens = start_pos - num_computed_tokens
break
if not self.encoder_cache_manager.can_allocate(
request, i, encoder_compute_budget, num_tokens_to_schedule
request, i, encoder_compute_budget, num_embeds_to_schedule
):
# The encoder cache is full or the encoder budget is exhausted.
# NOTE(woosuk): We assume that the encoder input tokens should
@ -992,14 +992,31 @@ class Scheduler(SchedulerInterface):
num_new_tokens = 0
break
# Calculate the number of embeddings to schedule in the current range
# of scheduled encoder placeholder tokens.
start_idx_rel = max(0, num_computed_tokens - start_pos)
end_idx_rel = min(
num_encoder_tokens, num_computed_tokens + num_new_tokens - start_pos
)
curr_embeds_start, curr_embeds_end = (
mm_feature.mm_position.get_embeds_indices_in_range(
start_idx_rel,
end_idx_rel,
)
)
# There's no embeddings in the current range of encoder placeholder tokens
# so we can skip the encoder input.
if curr_embeds_end - curr_embeds_start == 0:
continue
if self.ec_connector is not None and remote_cache_has_item[i]:
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
external_load_encoder_input.append(i)
num_tokens_to_schedule += num_encoder_tokens
num_embeds_to_schedule += num_encoder_embeds
continue
num_tokens_to_schedule += num_encoder_tokens
encoder_compute_budget -= num_encoder_tokens
num_embeds_to_schedule += num_encoder_embeds
encoder_compute_budget -= num_encoder_embeds
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
encoder_inputs_to_schedule.append(i)

View File

@ -209,10 +209,10 @@ class Request:
def get_finished_reason(self) -> FinishReason | None:
return RequestStatus.get_finished_reason(self.status)
def get_num_encoder_tokens(self, input_id: int) -> int:
def get_num_encoder_embeds(self, input_id: int) -> int:
assert input_id < len(self.mm_features)
num_tokens = self.mm_features[input_id].mm_position.length
return num_tokens
num_embeds = self.mm_features[input_id].mm_position.get_num_embeds
return num_embeds
def record_event(
self,

View File

@ -169,9 +169,7 @@ from .utils import (
MultiModalBudget,
add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache,
gather_mm_placeholders,
sanity_check_mm_encoder_outputs,
scatter_mm_placeholders,
)
if TYPE_CHECKING:
@ -2209,10 +2207,7 @@ class GPUModelRunner(
# Cache the encoder outputs by mm_hash
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
self.encoder_cache[mm_hash] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed,
)
self.encoder_cache[mm_hash] = output
logger.debug("Finish execute for mm hash %s", mm_hash)
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
@ -2263,6 +2258,13 @@ class GPUModelRunner(
num_encoder_tokens,
)
assert start_idx < end_idx
curr_embeds_start, curr_embeds_end = (
pos_info.get_embeds_indices_in_range(start_idx, end_idx)
)
# If there are no embeddings in the current range, we skip
# gathering the embeddings.
if curr_embeds_start == curr_embeds_end:
continue
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
@ -2270,16 +2272,14 @@ class GPUModelRunner(
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
else:
mm_embeds_item = encoder_output[start_idx:end_idx]
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
True if is_embed is None else is_embed
)
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds_req.append(mm_embeds_item)
if self.is_multimodal_pruning_enabled and self.uses_mrope:
@ -4508,31 +4508,8 @@ class GPUModelRunner(
dummy_encoder_outputs,
expected_num_items=max_mm_items_per_batch,
)
# NOTE: This happens when encoder cache needs to store
# the embeddings that encoder outputs are scattered onto.
# In this case we create dummy embeddings of size
# (max_tokens_for_modality, hidden_size) and scatter
# encoder output into it.
encoder_output_shape = dummy_encoder_outputs[0].shape
max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[
dummy_modality
]
if encoder_output_shape[0] < max_mm_tokens_per_item:
encoder_hidden_size = encoder_output_shape[-1]
expanded_outputs = []
for output in dummy_encoder_outputs:
expanded = output.new_zeros(
(max_mm_tokens_per_item, encoder_hidden_size)
)
num_tokens = output.shape[0]
expanded[:num_tokens].copy_(output)
expanded_outputs.append(expanded)
dummy_encoder_outputs = expanded_outputs
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
for i, output in enumerate(dummy_encoder_outputs):
self.encoder_cache[f"tmp_{i}"] = output
# Add `is_profile` here to pre-allocate communication buffers
hidden_states, last_hidden_states = self._dummy_run(

View File

@ -4,10 +4,12 @@ from collections import defaultdict
from dataclasses import dataclass, field
import torch
from typing_extensions import deprecated
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
@ -17,6 +19,8 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
logger = init_logger(__name__)
class MultiModalBudget:
"""Helper class to calculate budget information for multi-modal models."""
@ -198,6 +202,7 @@ def sanity_check_mm_encoder_outputs(
)
@deprecated("`scatter_mm_placeholders` is deprecated and will be removed in v0.15.0.")
def scatter_mm_placeholders(
embeds: torch.Tensor,
is_embed: torch.Tensor | None,
@ -226,6 +231,7 @@ def scatter_mm_placeholders(
return placeholders
@deprecated("`gather_mm_placeholders` is deprecated and will be removed in v0.15.0.")
def gather_mm_placeholders(
placeholders: torch.Tensor,
is_embed: torch.Tensor | None,