mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 09:57:09 +08:00
[Core][MM] Optimize encoder cache manager by operating with embeddings only (#30475)
Signed-off-by: Roger Wang <hey@rogerw.io> Co-authored-by: Sun Kim <sunytokki@gmail.com>
This commit is contained in:
parent
9fec0e13d5
commit
f5f51e5931
@ -60,12 +60,12 @@ def test_profiling(model_id: str, max_model_len: int):
|
||||
total_num_patches.item() + num_tiles.item() + 3
|
||||
) # image start, image, image end
|
||||
|
||||
profiled_tokens = profiler.get_mm_max_contiguous_tokens(
|
||||
profiled_tokens = profiler.get_mm_max_tokens(
|
||||
max_model_len,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
|
||||
assert total_tokens == profiled_tokens["image"]
|
||||
assert total_num_patches == profiled_tokens["image"]
|
||||
assert total_tokens == sum(
|
||||
placeholder.length
|
||||
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
|
||||
|
||||
@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
@ -410,6 +411,97 @@ def test_argsort_mm_positions(case):
|
||||
assert modality_idxs == expected_modality_idxs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "is_embed,expected",
    [
        # No mask: every placeholder position counts as an embedding.
        (None, 5),
        (torch.tensor([True] * 5), 5),
        (torch.tensor([False] * 5), 0),
        (torch.tensor([True, False, True, False, True]), 3),
        (torch.tensor([True]), 1),
    ],
)
def test_placeholder_range_get_num_embeds(is_embed, expected):
    """Check `get_num_embeds` for masked and unmasked placeholder ranges."""
    # Unmasked cases use a fixed length of 5, matching the `(None, 5)` case.
    if is_embed is None:
        num_positions = 5
    else:
        num_positions = len(is_embed)

    placeholder_range = PlaceholderRange(
        offset=0, length=num_positions, is_embed=is_embed
    )
    assert placeholder_range.get_num_embeds == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "is_embed,expected",
    [
        (None, None),
        (
            torch.tensor([False, True, False, True, True]),
            torch.tensor([0, 1, 1, 2, 3]),
        ),
        (torch.tensor([True, True, True]), torch.tensor([1, 2, 3])),
    ],
)
def test_placeholder_range_embeds_cumsum(is_embed, expected):
    """Check `embeds_cumsum` values and that the property result is cached."""
    # Unmasked ranges get a fixed length of 5; masked ones use the mask size.
    num_positions = 5 if is_embed is None else len(is_embed)
    placeholder_range = PlaceholderRange(
        offset=0, length=num_positions, is_embed=is_embed
    )

    cumsum = placeholder_range.embeds_cumsum
    if expected is None:
        assert cumsum is None
    else:
        assert torch.equal(cumsum, expected)
        # cached_property should return the same object on repeated access
        assert placeholder_range.embeds_cumsum is cumsum
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "is_embed,start_idx,end_idx,expected",
    [
        (None, 2, 4, (2, 4)),
        (torch.tensor([False, True, False, True, True]), 3, 5, (1, 3)),
        (torch.tensor([False, True, False, True, True]), 0, 2, (0, 1)),
        # Empty range: start and end map to the same embedding index.
        (torch.tensor([True, False, True, False]), 2, 2, (1, 1)),
    ],
)
def test_placeholder_range_get_embeds_indices_in_range(
    is_embed, start_idx, end_idx, expected
):
    """Check the mapping from a placeholder sub-range to embedding indices."""
    # Unmasked ranges get a fixed length of 5; masked ones use the mask size.
    num_positions = 5 if is_embed is None else len(is_embed)
    placeholder_range = PlaceholderRange(
        offset=0, length=num_positions, is_embed=is_embed
    )

    actual = placeholder_range.get_embeds_indices_in_range(start_idx, end_idx)
    assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "offset,is_embed,expected",
    [
        (0, None, [(0, 4)]),
        (2, torch.tensor([False, True, False, True, True]), [(3, 3), (5, 6)]),
        (0, torch.tensor([True] * 4), [(0, 3)]),
        (0, torch.tensor([False] * 4), []),
    ],
)
def test_placeholder_range_extract_embeds_range(offset, is_embed, expected):
    """Check `extract_embeds_range` against expected (start, end) pairs.

    The expected pairs are offset-shifted and use inclusive end positions,
    e.g. offset 2 with only mask index 1 set yields `(3, 3)`.
    """
    # Unmasked ranges get a fixed length of 5; masked ones use the mask size.
    num_positions = 5 if is_embed is None else len(is_embed)
    placeholder_range = PlaceholderRange(
        offset=offset, length=num_positions, is_embed=is_embed
    )

    actual = placeholder_range.extract_embeds_range()
    assert actual == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
|
||||
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
|
||||
@ -23,7 +24,7 @@ class MockRequest:
|
||||
)
|
||||
self.mm_features.append(feature)
|
||||
|
||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||
return self._token_counts[input_id]
|
||||
|
||||
|
||||
@ -162,8 +163,8 @@ def test_schedule_request_multi_images_respect_space_limit():
|
||||
|
||||
num_tokens_to_schedule = 0
|
||||
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
||||
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
|
||||
compute_budget -= req.get_num_encoder_tokens(0)
|
||||
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
|
||||
compute_budget -= req.get_num_encoder_embeds(0)
|
||||
|
||||
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
||||
|
||||
@ -174,7 +175,75 @@ def test_schedule_request_multi_images_respect_compute_limit():
|
||||
compute_budget = 10
|
||||
num_tokens_to_schedule = 0
|
||||
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
||||
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
|
||||
compute_budget -= req.get_num_encoder_tokens(0)
|
||||
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
|
||||
compute_budget -= req.get_num_encoder_embeds(0)
|
||||
|
||||
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
||||
|
||||
|
||||
def test_encoder_cache_with_is_embed_mask():
    """Allocation should consume one cache slot per embedding, not per token."""

    class MockRequestWithMask(MockRequest):
        # Report the embedding count of the placeholder range rather than the
        # fixed token counts stored by MockRequest.
        def get_num_encoder_embeds(self, input_id: int) -> int:
            return self.mm_features[input_id].mm_position.get_num_embeds

    num_placeholder_tokens = 100

    # 8 embedding positions among 100 placeholders: 5, 15, ..., 75.
    is_embed = torch.zeros(num_placeholder_tokens, dtype=torch.bool)
    is_embed[5:76:10] = True

    mm_position = PlaceholderRange(
        offset=0, length=num_placeholder_tokens, is_embed=is_embed
    )
    request = MockRequestWithMask("r1", ["img1"], [100])
    request.mm_features[0] = MultiModalFeatureSpec(
        data=None,
        modality="image",
        identifier="img1",
        mm_position=mm_position,
    )

    manager = EncoderCacheManager(cache_size=100)
    manager.allocate(request, 0)

    # Only the 8 embeddings are charged against the 100-slot cache.
    assert manager.num_free_slots == 92
    assert "img1" in manager.cached

    num_embeds = request.mm_features[0].mm_position.get_num_embeds
    assert num_embeds == 8

    # Token-based accounting would have used 100 slots for the same item.
    savings_ratio = num_placeholder_tokens / num_embeds
    assert savings_ratio == 12.5
|
||||
|
||||
|
||||
def test_encoder_cache_mask_based_retrieval():
    """Check embedding lookups for sub-ranges of a masked placeholder range.

    For each `[start_idx, end_idx)` window over the mask, the number of
    embeddings before and inside the window is computed directly from the
    mask and compared with what `get_embeds_indices_in_range` reports.
    """

    class MockRequestWithMask(MockRequest):
        # Report the embedding count of the placeholder range rather than the
        # fixed token counts stored by MockRequest.
        def get_num_encoder_embeds(self, input_id: int) -> int:
            return self.mm_features[input_id].mm_position.get_num_embeds

    is_embed = torch.tensor(
        [False, False, True, True, False, True, True, True, False, False]
    )
    mm_position = PlaceholderRange(offset=0, length=10, is_embed=is_embed)

    request = MockRequestWithMask("r1", ["img1"], [10])
    request.mm_features[0] = MultiModalFeatureSpec(
        data=None,
        modality="image",
        identifier="img1",
        mm_position=mm_position,
    )

    manager = EncoderCacheManager(cache_size=50)
    manager.allocate(request, 0)

    assert request.mm_features[0].mm_position.get_num_embeds == 5
    # Only the 5 embeddings are charged against the 50-slot cache.
    assert manager.num_free_slots == 45

    # (start_idx, end_idx, embeds before start_idx, embeds in [start, end))
    cases = [
        (2, 8, 0, 5),
        (0, 5, 0, 2),
    ]
    for start_idx, end_idx, expected_before, expected_in_range in cases:
        num_embeds_before = is_embed[:start_idx].sum().item()
        num_embeds_in_range = is_embed[start_idx:end_idx].sum().item()

        assert num_embeds_before == expected_before
        assert num_embeds_in_range == expected_in_range

        # The placeholder range must report the same slice of encoder output.
        assert mm_position.get_embeds_indices_in_range(start_idx, end_idx) == (
            expected_before,
            expected_before + expected_in_range,
        )
|
||||
|
||||
@ -38,7 +38,7 @@ class MockRequest:
|
||||
)
|
||||
self.mm_features.append(feature)
|
||||
|
||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||
assert input_id < len(self._token_counts)
|
||||
return self._token_counts[input_id]
|
||||
|
||||
|
||||
@ -144,7 +144,7 @@ class ECExampleConnector(ECConnectorBase):
|
||||
Update ECConnector state after encoder cache allocation.
|
||||
"""
|
||||
mm_hash = request.mm_features[index].identifier
|
||||
num_encoder_token = request.get_num_encoder_tokens(index)
|
||||
num_encoder_token = request.get_num_encoder_embeds(index)
|
||||
# Insert mm_hash only if this block has not been recorded yet.
|
||||
self._mm_datas_need_loads[mm_hash] = num_encoder_token
|
||||
|
||||
|
||||
@ -713,17 +713,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> int:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
video_soft_tokens = self.get_num_video_tokens(
|
||||
num_video_soft_tokens = self.get_num_video_tokens(
|
||||
image_width=target_width,
|
||||
image_height=target_height,
|
||||
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
|
||||
image_processor=None,
|
||||
)
|
||||
|
||||
# NOTE: By default in Qwen3-VL, one video token is converted to
|
||||
# "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501
|
||||
formatted_video_soft_tokens = video_soft_tokens * 12.5
|
||||
return int(formatted_video_soft_tokens)
|
||||
return num_video_soft_tokens
|
||||
|
||||
def _calculate_timestamps(
|
||||
self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int
|
||||
|
||||
@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from functools import cached_property, partial
|
||||
from itertools import accumulate
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@ -169,11 +169,42 @@ class PlaceholderRange:
|
||||
between `offset` and `offset + length` to assign embeddings to.
|
||||
"""
|
||||
|
||||
def get_num_embeds(self) -> int:
|
||||
@cached_property
|
||||
def embeds_cumsum(self) -> torch.Tensor | None:
|
||||
if self.is_embed is None:
|
||||
return None
|
||||
|
||||
return self.is_embed.cumsum(dim=0)
|
||||
|
||||
@cached_property
|
||||
def get_num_embeds(self) -> int:
|
||||
if self.embeds_cumsum is None:
|
||||
return self.length
|
||||
|
||||
return int(self.is_embed.sum().item())
|
||||
return int(self.embeds_cumsum[-1])
|
||||
|
||||
def get_embeds_indices_in_range(
    self, start_idx: int, end_idx: int
) -> tuple[int, int]:
    """
    Returns the starting and ending indices of the embeddings of encoder outputs
    in the range of [start_idx, end_idx) in the placeholders.

    `start_idx` and `end_idx` index into `is_embed` (i.e. they are relative
    to the start of the placeholder range, not to the prompt). The returned
    pair is a half-open slice `[embeds_start, embeds_end)` into the encoder
    output; it is empty when the range contains no embeddings.

    For example, given:
    PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])

    If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
    the second and the third embeddings from the encoder output.

    Returns `(start_idx, end_idx)` unchanged when `is_embed` is `None`,
    since every placeholder position is then an embedding.
    """
    cumsum = self.embeds_cumsum
    if cumsum is None:
        return start_idx, end_idx

    def num_embeds_before(idx: int) -> int:
        # cumsum[idx - 1] counts the embeddings at positions [0, idx).
        # For idx <= 0 there are none; indexing cumsum[-1] here would wrap
        # around and wrongly return the total number of embeddings.
        return int(cumsum[idx - 1]) if idx > 0 else 0

    return num_embeds_before(start_idx), num_embeds_before(end_idx)
|
||||
|
||||
def extract_embeds_range(self) -> list[tuple[int, int]]:
|
||||
"""Extract the start and end indices of the embedded region in prompt.
|
||||
@ -188,7 +219,7 @@ class PlaceholderRange:
|
||||
Returns full placeholder range if `is_embed` is `None`.
|
||||
"""
|
||||
if self.is_embed is None:
|
||||
return [(self.offset, self.offset + self.length)]
|
||||
return [(self.offset, self.offset + self.length - 1)]
|
||||
|
||||
mask_i = self.is_embed.int()
|
||||
starts = torch.nonzero(
|
||||
|
||||
@ -274,15 +274,11 @@ class MultiModalProfiler(Generic[_I]):
|
||||
def _get_mm_num_tokens(
|
||||
self,
|
||||
mm_inputs: MultiModalInputs,
|
||||
mm_embeddings_only: bool = True,
|
||||
) -> Mapping[str, int]:
|
||||
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
||||
|
||||
return {
|
||||
modality: sum(
|
||||
item.get_num_embeds() if mm_embeddings_only else item.length
|
||||
for item in placeholders
|
||||
)
|
||||
modality: sum(item.get_num_embeds for item in placeholders)
|
||||
for modality, placeholders in placeholders_by_modality.items()
|
||||
}
|
||||
|
||||
@ -328,12 +324,15 @@ class MultiModalProfiler(Generic[_I]):
|
||||
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
||||
)
|
||||
|
||||
def _get_mm_max_tokens(
|
||||
def get_mm_max_tokens(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
mm_embeddings_only: bool = True,
|
||||
) -> Mapping[str, int]:
|
||||
"""
|
||||
Returns the maximum number of embeddings per item of each modality, excluding
|
||||
any break/text tokens in-between multimodal embeddings/encoder outputs.
|
||||
"""
|
||||
if mm_counts is None:
|
||||
mm_counts = self.get_mm_limits()
|
||||
|
||||
@ -349,21 +348,4 @@ class MultiModalProfiler(Generic[_I]):
|
||||
}
|
||||
|
||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
||||
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
|
||||
|
||||
def get_mm_max_contiguous_tokens(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
) -> Mapping[str, int]:
|
||||
"""
|
||||
Returns the maximum length of the multimodal (image placeholders+text)
|
||||
tokens, including any break/text tokens in-between image embeddings.
|
||||
|
||||
`<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
|
||||
Returns 9, even when the number of image embeddings is 6.
|
||||
|
||||
This is important to take into account when profiling and
|
||||
initializing the encoder cache size.
|
||||
"""
|
||||
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
|
||||
return self._get_mm_num_tokens(mm_inputs)
|
||||
|
||||
@ -164,7 +164,7 @@ class MultiModalRegistry:
|
||||
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
|
||||
)
|
||||
|
||||
return profiler.get_mm_max_contiguous_tokens(
|
||||
return profiler.get_mm_max_tokens(
|
||||
seq_len,
|
||||
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
|
||||
)
|
||||
|
||||
@ -39,20 +39,26 @@ class EncoderCacheManager:
|
||||
space for new embeddings.
|
||||
Oldest cached embeddings with no request referenced will be first evicted.
|
||||
|
||||
NOTE: The EncoderCacheManager operates on the level of multimodal embeddings
|
||||
instead of encoder tokens (i.e. all tokens that represent the multimodal data
|
||||
in the input sequence). This means all break/text tokens in-between multimodal
|
||||
embeddings are not considered with respect to the cache size and the number
|
||||
of free slots.
|
||||
|
||||
Args:
|
||||
cache_size: Limit the size of the cache, measured by the number of
|
||||
tokens from the input sequence.
|
||||
encoder embeddings from the input sequence.
|
||||
|
||||
Attributes:
|
||||
cache_size: Total cache capacity in encoder tokens.
|
||||
num_free_slots: Current available cache capacity in encoder tokens.
|
||||
cache_size: Total cache capacity in encoder embeddings.
|
||||
num_free_slots: Current available cache capacity in encoder embeddings.
|
||||
num_freeable_slots: Capacity that can be immediately reclaimed by
|
||||
evicting entries with zero references (in encoder tokens).
|
||||
evicting entries with zero references (in encoder embeddings).
|
||||
cached: Mapping from mm_hash to a set of request IDs that currently
|
||||
reference the cached entry. If the set is empty, the entry exists
|
||||
but is not referenced by any request and is eligible for
|
||||
reclamation.
|
||||
freeable: List of tuples (mm_hash, num_tokens) representing entries
|
||||
freeable: List of tuples (mm_hash, num_encoder_embeds) representing entries
|
||||
that no currently running request needs and that can be freed to
|
||||
make space when needed.
|
||||
freed: List of mm_hash strings that were actually evicted since the
|
||||
@ -67,7 +73,7 @@ class EncoderCacheManager:
|
||||
# mm_hash of mm_data => ids of requests that reference the mm_data
|
||||
self.cached: dict[str, set[str]] = {}
|
||||
|
||||
# mm_hash of mm_data => num_encoder_tokens of the mm_data
|
||||
# mm_hash of mm_data => num_encoder_embeds of the mm_data
|
||||
self.freeable: OrderedDict[str, int] = OrderedDict()
|
||||
self.freed: list[str] = []
|
||||
|
||||
@ -93,8 +99,8 @@ class EncoderCacheManager:
|
||||
|
||||
# Cached but currently not referenced by any request
|
||||
if not self.cached[mm_hash]:
|
||||
num_tokens = self.freeable.pop(mm_hash)
|
||||
self.num_freeable_slots -= num_tokens
|
||||
num_encoder_embeds = self.freeable.pop(mm_hash)
|
||||
self.num_freeable_slots -= num_encoder_embeds
|
||||
|
||||
self.cached[mm_hash].add(request.request_id)
|
||||
return True
|
||||
@ -104,7 +110,7 @@ class EncoderCacheManager:
|
||||
request: Request,
|
||||
input_id: int,
|
||||
encoder_compute_budget: int,
|
||||
num_tokens_to_schedule: int,
|
||||
num_embeds_to_schedule: int,
|
||||
) -> bool:
|
||||
"""Check if there's sufficient cache space for a multimodal input.
|
||||
If there is, return True and update EncoderCacheManager state.
|
||||
@ -121,9 +127,9 @@ class EncoderCacheManager:
|
||||
Args:
|
||||
request: The request containing the multimodal input.
|
||||
input_id: Index of the multimodal input within the request.
|
||||
encoder_compute_budget: Number of encoder tokens allowed to be
|
||||
encoder_compute_budget: Number of encoder embeddings allowed to be
|
||||
computed when this method is invoked.
|
||||
num_tokens_to_schedule: Number of tokens already scheduled to be
|
||||
num_embeds_to_schedule: Number of encoder embeddings already scheduled to be
|
||||
allocated with cache space when this method is invoked.
|
||||
|
||||
Returns:
|
||||
@ -134,30 +140,30 @@ class EncoderCacheManager:
|
||||
Note: This method does not allocate physical memory for the encoder
|
||||
output but only the state of EncoderCacheManager.
|
||||
"""
|
||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||
num_embeds = request.get_num_encoder_embeds(input_id)
|
||||
|
||||
# Not enough compute budget
|
||||
if num_tokens > encoder_compute_budget:
|
||||
if num_embeds > encoder_compute_budget:
|
||||
return False
|
||||
|
||||
num_tokens += num_tokens_to_schedule
|
||||
num_embeds += num_embeds_to_schedule
|
||||
|
||||
# Enough free slots
|
||||
if num_tokens <= self.num_free_slots:
|
||||
if num_embeds <= self.num_free_slots:
|
||||
return True
|
||||
|
||||
# Not enough reclaimable slots
|
||||
if num_tokens > self.num_freeable_slots:
|
||||
if num_embeds > self.num_freeable_slots:
|
||||
return False
|
||||
|
||||
# Not enough free slots but enough reclaimable slots
|
||||
# NOTE: Eviction takes place here, but physical memory is not freed
|
||||
# until model runner is notified by the scheduler output.
|
||||
while num_tokens > self.num_free_slots:
|
||||
mm_hash, num_free_token = self.freeable.popitem(last=False)
|
||||
while num_embeds > self.num_free_slots:
|
||||
mm_hash, num_free_embeds = self.freeable.popitem(last=False)
|
||||
del self.cached[mm_hash]
|
||||
self.freed.append(mm_hash)
|
||||
self.num_free_slots += num_free_token
|
||||
self.num_free_slots += num_free_embeds
|
||||
return True
|
||||
|
||||
def allocate(self, request: Request, input_id: int) -> None:
|
||||
@ -176,16 +182,16 @@ class EncoderCacheManager:
|
||||
if mm_hash not in self.cached:
|
||||
self.cached[mm_hash] = set()
|
||||
|
||||
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
|
||||
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||
|
||||
# NOTE: Encoder cache should always have enough space for encoder inputs
|
||||
# that are scheduled since eviction takes place at can_allocate().
|
||||
assert self.num_free_slots >= num_encoder_tokens
|
||||
assert self.num_freeable_slots >= num_encoder_tokens
|
||||
assert self.num_free_slots >= num_encoder_embeds
|
||||
assert self.num_freeable_slots >= num_encoder_embeds
|
||||
|
||||
self.cached[mm_hash].add(request_id)
|
||||
self.num_free_slots -= num_encoder_tokens
|
||||
self.num_freeable_slots -= num_encoder_tokens
|
||||
self.num_free_slots -= num_encoder_embeds
|
||||
self.num_freeable_slots -= num_encoder_embeds
|
||||
|
||||
def get_cached_input_ids(self, request: Request) -> set[int]:
|
||||
"""Get all cached multimodal input IDs for a request.
|
||||
@ -206,7 +212,7 @@ class EncoderCacheManager:
|
||||
|
||||
When the reference set for the corresponding `mm_hash` becomes empty,
|
||||
the entry is appended to `freeable` and `num_freeable_slots` is
|
||||
increased by the number of encoder tokens for that input.
|
||||
increased by the number of encoder embeddings for that input.
|
||||
|
||||
The entry is NOT physically freed until capacity is needed (e.g., by
|
||||
`can_allocate`).
|
||||
@ -218,9 +224,9 @@ class EncoderCacheManager:
|
||||
return
|
||||
self.cached[mm_hash].discard(req_id)
|
||||
if not self.cached[mm_hash]:
|
||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||
self.freeable[mm_hash] = num_tokens
|
||||
self.num_freeable_slots += num_tokens
|
||||
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||
self.freeable[mm_hash] = num_encoder_embeds
|
||||
self.num_freeable_slots += num_encoder_embeds
|
||||
|
||||
def free(self, request: Request) -> None:
|
||||
"""Free all encoder input cache reference held by *request*.
|
||||
@ -361,20 +367,20 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
|
||||
request: Request,
|
||||
input_id: int,
|
||||
encoder_compute_budget: int,
|
||||
num_tokens_to_schedule: int,
|
||||
num_embeds_to_schedule: int,
|
||||
) -> bool:
|
||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||
# Not enough compute budget
|
||||
if num_tokens > encoder_compute_budget:
|
||||
if num_encoder_embeds > encoder_compute_budget:
|
||||
return False
|
||||
|
||||
num_tokens += num_tokens_to_schedule
|
||||
num_encoder_embeds += num_embeds_to_schedule
|
||||
# Enough free slots
|
||||
return num_tokens <= self.num_free_slots
|
||||
return num_encoder_embeds <= self.num_free_slots
|
||||
|
||||
def allocate(self, request: Request, input_id: int) -> None:
|
||||
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
|
||||
self.num_free_slots -= num_encoder_tokens
|
||||
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||
self.num_free_slots -= num_encoder_embeds
|
||||
|
||||
mm_hash = request.mm_features[input_id].identifier
|
||||
self.freed.append(mm_hash)
|
||||
@ -392,5 +398,5 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
|
||||
return freed
|
||||
|
||||
def free_encoder_input(self, request: Request, input_id: int) -> None:
|
||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||
self.num_free_slots += num_tokens
|
||||
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||
self.num_free_slots += num_encoder_embeds
|
||||
|
||||
@ -355,11 +355,11 @@ class Scheduler(SchedulerInterface):
|
||||
if preempted_encoder_inputs:
|
||||
# Restore encoder compute budget if the preempted
|
||||
# request had encoder inputs scheduled in this step.
|
||||
num_tokens_to_restore = sum(
|
||||
preempted_req.get_num_encoder_tokens(i)
|
||||
num_embeds_to_restore = sum(
|
||||
preempted_req.get_num_encoder_embeds(i)
|
||||
for i in preempted_encoder_inputs
|
||||
)
|
||||
encoder_compute_budget += num_tokens_to_restore
|
||||
encoder_compute_budget += num_embeds_to_restore
|
||||
req_index -= 1
|
||||
else:
|
||||
preempted_req = self.running.pop()
|
||||
@ -911,10 +911,11 @@ class Scheduler(SchedulerInterface):
|
||||
# multiple encoder inputs per request), we need to create temporary
|
||||
# trackers for accounting at the encoder input level.
|
||||
mm_hashes_to_schedule = set()
|
||||
num_tokens_to_schedule = 0
|
||||
num_embeds_to_schedule = 0
|
||||
for i, mm_feature in enumerate(mm_features):
|
||||
start_pos = mm_feature.mm_position.offset
|
||||
num_encoder_tokens = mm_feature.mm_position.length
|
||||
num_encoder_embeds = mm_feature.mm_position.get_num_embeds
|
||||
|
||||
# The encoder output is needed if the two ranges overlap:
|
||||
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
|
||||
@ -970,9 +971,8 @@ class Scheduler(SchedulerInterface):
|
||||
):
|
||||
num_new_tokens = start_pos - num_computed_tokens
|
||||
break
|
||||
|
||||
if not self.encoder_cache_manager.can_allocate(
|
||||
request, i, encoder_compute_budget, num_tokens_to_schedule
|
||||
request, i, encoder_compute_budget, num_embeds_to_schedule
|
||||
):
|
||||
# The encoder cache is full or the encoder budget is exhausted.
|
||||
# NOTE(woosuk): We assume that the encoder input tokens should
|
||||
@ -992,14 +992,31 @@ class Scheduler(SchedulerInterface):
|
||||
num_new_tokens = 0
|
||||
break
|
||||
|
||||
# Calculate the number of embeddings to schedule in the current range
|
||||
# of scheduled encoder placeholder tokens.
|
||||
start_idx_rel = max(0, num_computed_tokens - start_pos)
|
||||
end_idx_rel = min(
|
||||
num_encoder_tokens, num_computed_tokens + num_new_tokens - start_pos
|
||||
)
|
||||
curr_embeds_start, curr_embeds_end = (
|
||||
mm_feature.mm_position.get_embeds_indices_in_range(
|
||||
start_idx_rel,
|
||||
end_idx_rel,
|
||||
)
|
||||
)
|
||||
# There are no embeddings in the current range of encoder placeholder tokens
|
||||
# so we can skip the encoder input.
|
||||
if curr_embeds_end - curr_embeds_start == 0:
|
||||
continue
|
||||
|
||||
if self.ec_connector is not None and remote_cache_has_item[i]:
|
||||
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
||||
external_load_encoder_input.append(i)
|
||||
num_tokens_to_schedule += num_encoder_tokens
|
||||
num_embeds_to_schedule += num_encoder_embeds
|
||||
continue
|
||||
|
||||
num_tokens_to_schedule += num_encoder_tokens
|
||||
encoder_compute_budget -= num_encoder_tokens
|
||||
num_embeds_to_schedule += num_encoder_embeds
|
||||
encoder_compute_budget -= num_encoder_embeds
|
||||
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
||||
encoder_inputs_to_schedule.append(i)
|
||||
|
||||
|
||||
@ -209,10 +209,10 @@ class Request:
|
||||
def get_finished_reason(self) -> FinishReason | None:
|
||||
return RequestStatus.get_finished_reason(self.status)
|
||||
|
||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||
assert input_id < len(self.mm_features)
|
||||
num_tokens = self.mm_features[input_id].mm_position.length
|
||||
return num_tokens
|
||||
num_embeds = self.mm_features[input_id].mm_position.get_num_embeds
|
||||
return num_embeds
|
||||
|
||||
def record_event(
|
||||
self,
|
||||
|
||||
@ -169,9 +169,7 @@ from .utils import (
|
||||
MultiModalBudget,
|
||||
add_kv_sharing_layers_to_kv_cache_groups,
|
||||
bind_kv_cache,
|
||||
gather_mm_placeholders,
|
||||
sanity_check_mm_encoder_outputs,
|
||||
scatter_mm_placeholders,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -2209,10 +2207,7 @@ class GPUModelRunner(
|
||||
|
||||
# Cache the encoder outputs by mm_hash
|
||||
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
|
||||
self.encoder_cache[mm_hash] = scatter_mm_placeholders(
|
||||
output,
|
||||
is_embed=pos_info.is_embed,
|
||||
)
|
||||
self.encoder_cache[mm_hash] = output
|
||||
logger.debug("Finish execute for mm hash %s", mm_hash)
|
||||
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
|
||||
|
||||
@ -2263,6 +2258,13 @@ class GPUModelRunner(
|
||||
num_encoder_tokens,
|
||||
)
|
||||
assert start_idx < end_idx
|
||||
curr_embeds_start, curr_embeds_end = (
|
||||
pos_info.get_embeds_indices_in_range(start_idx, end_idx)
|
||||
)
|
||||
# If there are no embeddings in the current range, we skip
|
||||
# gathering the embeddings.
|
||||
if curr_embeds_start == curr_embeds_end:
|
||||
continue
|
||||
|
||||
mm_hash = mm_feature.identifier
|
||||
encoder_output = self.encoder_cache.get(mm_hash, None)
|
||||
@ -2270,16 +2272,14 @@ class GPUModelRunner(
|
||||
|
||||
if (is_embed := pos_info.is_embed) is not None:
|
||||
is_embed = is_embed[start_idx:end_idx]
|
||||
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
|
||||
else:
|
||||
mm_embeds_item = encoder_output[start_idx:end_idx]
|
||||
|
||||
req_start_pos = req_start_idx + start_pos - num_computed_tokens
|
||||
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
|
||||
True if is_embed is None else is_embed
|
||||
)
|
||||
|
||||
mm_embeds_item = gather_mm_placeholders(
|
||||
encoder_output[start_idx:end_idx],
|
||||
is_embed=is_embed,
|
||||
)
|
||||
mm_embeds_req.append(mm_embeds_item)
|
||||
|
||||
if self.is_multimodal_pruning_enabled and self.uses_mrope:
|
||||
@ -4508,31 +4508,8 @@ class GPUModelRunner(
|
||||
dummy_encoder_outputs,
|
||||
expected_num_items=max_mm_items_per_batch,
|
||||
)
|
||||
|
||||
# NOTE: This happens when encoder cache needs to store
|
||||
# the embeddings that encoder outputs are scattered onto.
|
||||
# In this case we create dummy embeddings of size
|
||||
# (max_tokens_for_modality, hidden_size) and scatter
|
||||
# encoder output into it.
|
||||
encoder_output_shape = dummy_encoder_outputs[0].shape
|
||||
max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[
|
||||
dummy_modality
|
||||
]
|
||||
if encoder_output_shape[0] < max_mm_tokens_per_item:
|
||||
encoder_hidden_size = encoder_output_shape[-1]
|
||||
expanded_outputs = []
|
||||
for output in dummy_encoder_outputs:
|
||||
expanded = output.new_zeros(
|
||||
(max_mm_tokens_per_item, encoder_hidden_size)
|
||||
)
|
||||
num_tokens = output.shape[0]
|
||||
expanded[:num_tokens].copy_(output)
|
||||
expanded_outputs.append(expanded)
|
||||
|
||||
dummy_encoder_outputs = expanded_outputs
|
||||
|
||||
# Cache the dummy encoder outputs.
|
||||
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
||||
for i, output in enumerate(dummy_encoder_outputs):
|
||||
self.encoder_cache[f"tmp_{i}"] = output
|
||||
|
||||
# Add `is_profile` here to pre-allocate communication buffers
|
||||
hidden_states, last_hidden_states = self._dummy_run(
|
||||
|
||||
@ -4,10 +4,12 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.multimodal.cache import processor_only_cache_from_config
|
||||
@ -17,6 +19,8 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiModalBudget:
|
||||
"""Helper class to calculate budget information for multi-modal models."""
|
||||
@ -198,6 +202,7 @@ def sanity_check_mm_encoder_outputs(
|
||||
)
|
||||
|
||||
|
||||
@deprecated("`scatter_mm_placeholders` is deprecated and will be removed in v0.15.0.")
|
||||
def scatter_mm_placeholders(
|
||||
embeds: torch.Tensor,
|
||||
is_embed: torch.Tensor | None,
|
||||
@ -226,6 +231,7 @@ def scatter_mm_placeholders(
|
||||
return placeholders
|
||||
|
||||
|
||||
@deprecated("`gather_mm_placeholders` is deprecated and will be removed in v0.15.0.")
|
||||
def gather_mm_placeholders(
|
||||
placeholders: torch.Tensor,
|
||||
is_embed: torch.Tensor | None,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user