[Multimodal] Remove legacy multimodal fields in favor of MultiModalFeatureSpec (#24548)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
Flora Feng authored on 2025-09-12 06:42:23 -07:00; committed by GitHub
parent 72fc8aa412
commit 0377802c20
13 changed files with 102 additions and 116 deletions
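Note (illustrative, not part of the diff): this commit collapses the per-request parallel lists mm_kwargs, mm_hashes and mm_positions into the single field mm_features: list[MultiModalFeatureSpec]. The sketch below shows only the fields the diff relies on; the authoritative definition lives in vllm/multimodal/inputs.py and may carry more.

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class MultiModalFeatureSpec:  # simplified sketch, not the real definition
        data: Optional["MultiModalKwargsItem"]  # encoder kwargs; may be None
        modality: str                           # e.g. "image"
        identifier: str                         # mm hash, used as the encoder-cache key
        mm_position: "PlaceholderRange"         # offset/length of the placeholder tokens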

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
@@ -9,8 +10,17 @@ class MockRequest:
def __init__(self, request_id, mm_hashes, token_counts):
self.request_id = request_id
self.mm_hashes = mm_hashes
self._token_counts = token_counts
self.mm_features = []
for i, mm_hash in enumerate(mm_hashes):
feature = MultiModalFeatureSpec(
data=None,
modality="image",
identifier=mm_hash,
mm_position=PlaceholderRange(offset=0,
length=self._token_counts[i]),
)
self.mm_features.append(feature)
def get_num_encoder_tokens(self, input_id: int) -> int:
return self._token_counts[input_id]
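Note (illustrative, not part of the diff): minimal usage of the updated test helper, with made-up hashes and token counts.

    req = MockRequest("req-1", mm_hashes=["hash-a", "hash-b"], token_counts=[16, 32])
    assert req.mm_features[1].identifier == "hash-b"
    assert req.mm_features[1].mm_position.length == 32
    assert req.get_num_encoder_tokens(1) == 32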

View File

@@ -64,9 +64,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData(
req_id=req_id,
prompt_token_ids=[1, 2, 3],
mm_kwargs=[],
mm_hashes=[],
mm_positions=[],
mm_features=[],
sampling_params=SamplingParams(),
pooling_params=PoolingParams(),
block_ids=([0], ), # block_ids should be tuple[list[int]]

View File

@@ -203,9 +203,7 @@ def _construct_cached_request_state(req_id_suffix: int):
prompt_token_ids=prompt_token_ids,
sampling_params=_create_sampling_params(),
pooling_params=None,
mm_kwargs=[],
mm_positions=[],
mm_hashes=[],
mm_features=[],
block_ids=([], ),
generator=None,
num_computed_tokens=len(output_token_ids),

View File

@@ -118,9 +118,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
NewRequestData(
req_id=req_id,
prompt_token_ids=[1, 2, 3],
mm_kwargs=[],
mm_hashes=[],
mm_positions=[],
mm_features=[],
sampling_params=SamplingParams(),
pooling_params=None,
block_ids=([0], ),

View File

@@ -300,11 +300,12 @@ class SharedStorageConnector(KVConnectorBase_V1):
total_need_load = 0
for new_req in scheduler_output.scheduled_new_reqs:
if new_req.req_id in self._requests_need_load:
meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=False,
mm_hashes=new_req.mm_hashes)
meta.add_request(
token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=False,
mm_hashes=[f.identifier for f in new_req.mm_features])
total_need_load += 1
else:
# NOTE: here, we set the store and load being exclusive,
@@ -312,11 +313,12 @@ class SharedStorageConnector(KVConnectorBase_V1):
# NOTE(rob): for this debug implementation, we only cache
# the original prompt tokens.
if not self._found_match_for_request(new_req):
meta.add_request(token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=True,
mm_hashes=new_req.mm_hashes)
meta.add_request(
token_ids=new_req.prompt_token_ids,
block_ids=new_req.block_ids[0],
block_size=self._block_size,
is_store=True,
mm_hashes=[f.identifier for f in new_req.mm_features])
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
@@ -341,11 +343,12 @@ class SharedStorageConnector(KVConnectorBase_V1):
# of the block_ids for the request.
block_ids = new_block_ids[0]
meta.add_request(token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size,
is_store=False,
mm_hashes=request.mm_hashes)
meta.add_request(
token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size,
is_store=False,
mm_hashes=[f.identifier for f in request.mm_features])
total_need_load += 1
assert total_need_load == len(self._requests_need_load)
@@ -364,10 +367,10 @@ class SharedStorageConnector(KVConnectorBase_V1):
"""
num_tokens_to_check = align_to_block_size(
len(request.prompt_token_ids) - 1, self._block_size)
foldername = self._generate_foldername_debug(torch.tensor(
request.prompt_token_ids)[:num_tokens_to_check],
request.mm_hashes,
create_folder=False)
foldername = self._generate_foldername_debug(
torch.tensor(request.prompt_token_ids)[:num_tokens_to_check],
[f.identifier for f in request.mm_features],
create_folder=False)
return os.path.exists(foldername)
def _generate_foldername_debug(
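Note (illustrative, not part of the diff): the comprehension [f.identifier for f in ...mm_features] now recurs at every add_request() call site in this connector; a small hypothetical helper would keep those calls shorter.

    def _mm_hashes(mm_features):
        """Collect the cache identifiers of a request's multimodal features."""
        return [f.identifier for f in mm_features]

    # e.g. meta.add_request(..., mm_hashes=_mm_hashes(new_req.mm_features))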

View File

@@ -86,7 +86,7 @@ class EncoderCacheManager:
Returns:
True if the encoder output for this input is already cached
"""
mm_hash = request.mm_hashes[input_id]
mm_hash = request.mm_features[input_id].identifier
# Not cached at all
if mm_hash not in self.cached:
return False
@@ -167,7 +167,7 @@ class EncoderCacheManager:
This method assumes can_allocate() returned True for the same input.
"""
mm_hash = request.mm_hashes[input_id]
mm_hash = request.mm_features[input_id].identifier
request_id = request.request_id
if mm_hash not in self.cached:
self.cached[mm_hash] = set()
@@ -193,8 +193,8 @@ class EncoderCacheManager:
"""
return {
input_id
for input_id in range(len(request.mm_hashes))
if request.mm_hashes[input_id] in self.cached
for input_id in range(len(request.mm_features))
if request.mm_features[input_id].identifier in self.cached
}
def free_encoder_input(self, request: Request, input_id: int) -> None:
@@ -208,7 +208,7 @@ class EncoderCacheManager:
`can_allocate`).
"""
req_id = request.request_id
mm_hash = request.mm_hashes[input_id]
mm_hash = request.mm_features[input_id].identifier
# The mm_hash not in cache or the req_id set is empty
if not self.cached.get(mm_hash, None):
return
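Note (illustrative, not part of the diff): all four methods above share the same lookup shape: translate an encoder input_id into the feature's identifier and use it as the key into self.cached, which (as inferred from this file) maps an mm hash to the set of request ids referencing that encoder output.

    def _is_input_cached(cached: dict, request, input_id: int) -> bool:
        # cached: dict[str, set[str]] -- mm hash -> request ids holding a reference
        mm_hash = request.mm_features[input_id].identifier
        return bool(cached.get(mm_hash))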

View File

@@ -418,9 +418,9 @@ def need_extra_keys(request: Request) -> bool:
# Multimodal requests need to include the MM hash.
# LoRA requests need to include the LoRA ID.
# Request with provided cache salt need to include the salt.
return bool(request.mm_hashes) or (request.lora_request
is not None) or (request.cache_salt
is not None)
return bool(request.mm_features) or (request.lora_request
is not None) or (request.cache_salt
is not None)
def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
@@ -442,32 +442,28 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
"""
extra_keys: list[Any] = []
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if not mm_positions:
mm_features = request.mm_features
if not mm_features:
return extra_keys, start_mm_idx
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match. This "
"is likely because you did not enable MM hashing. "
"Please set `mm_processor_cache_gb > 0`.")
# Note that we assume mm_positions is sorted by offset.
# Note that we assume mm_features are sorted by mm_position.offset.
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx:
last_pos = mm_features[-1].mm_position
if last_pos.offset + last_pos.length < start_token_idx:
return extra_keys, start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
if start_mm_idx < 0:
assert -start_mm_idx <= len(mm_positions)
start_mm_idx = len(mm_positions) + start_mm_idx
assert -start_mm_idx <= len(mm_features)
start_mm_idx = len(mm_features) + start_mm_idx
curr_mm_idx = start_mm_idx
while mm_positions and curr_mm_idx < len(mm_positions):
assert mm_hashes[curr_mm_idx] is not None
offset = mm_positions[curr_mm_idx].offset
length = mm_positions[curr_mm_idx].length
while mm_features and curr_mm_idx < len(mm_features):
mm_feature = mm_features[curr_mm_idx]
assert mm_feature.identifier is not None
offset = mm_feature.mm_position.offset
length = mm_feature.mm_position.length
if end_token_idx > offset:
if start_token_idx > offset + length:
# This block has passed the current mm input.
@@ -475,7 +471,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
continue
# The block contains the current mm input.
extra_keys.append(mm_hashes[curr_mm_idx])
extra_keys.append(mm_feature.identifier)
if end_token_idx >= offset + length:
# If this block contains the end of the current mm input,
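Note (illustrative, not part of the diff): a worked example of the branch structure above with hypothetical numbers. A feature at offset 10 with length 6 occupies prompt tokens [10, 16); a block covering tokens [12, 20) overlaps it, so the feature's identifier is added to the block's extra hash keys, while a block starting at token 18 has already passed it.

    offset, length = 10, 6                      # feature occupies prompt tokens [10, 16)
    for start_token_idx, end_token_idx in [(12, 20), (18, 34)]:
        passed = start_token_idx > offset + length
        include = end_token_idx > offset and not passed
        print((start_token_idx, end_token_idx), "include identifier:", include)
    # (12, 20) -> True, (18, 34) -> False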

View File

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata)
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
@@ -27,9 +27,7 @@ class NewRequestData:
req_id: str
prompt_token_ids: list[int]
mm_kwargs: list[MultiModalKwargsItem]
mm_hashes: list[str]
mm_positions: list[PlaceholderRange]
mm_features: list[MultiModalFeatureSpec]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...]
@@ -45,9 +43,7 @@ class NewRequestData:
return cls(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
mm_kwargs=request.mm_kwargs,
mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions,
mm_features=request.mm_features,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
block_ids=block_ids,
@@ -59,9 +55,7 @@ class NewRequestData:
return (f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids={self.prompt_token_ids},"
f"mm_kwargs={self.mm_kwargs},"
f"mm_hashes={self.mm_hashes},"
f"mm_positions={self.mm_positions},"
f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
@@ -73,9 +67,7 @@ class NewRequestData:
return (f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids_len={len(self.prompt_token_ids)},"
f"mm_kwargs={self.mm_kwargs},"
f"mm_hashes={self.mm_hashes},"
f"mm_positions={self.mm_positions},"
f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
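Note (illustrative, not part of the diff): consumers that still want the old parallel lists can rebuild them from mm_features in one pass, exactly the way Request used to derive its legacy attributes (removed further down in this commit).

    def legacy_views(mm_features):
        """Rebuild the removed parallel lists from the unified field (sketch)."""
        mm_kwargs = [f.data for f in mm_features]
        mm_hashes = [f.identifier for f in mm_features]
        mm_positions = [f.mm_position for f in mm_features]
        return mm_kwargs, mm_hashes, mm_positions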

View File

@@ -736,18 +736,18 @@ class Scheduler(SchedulerInterface):
if num_new_tokens == 0 or not request.has_encoder_inputs:
return [], num_new_tokens, encoder_compute_budget
encoder_inputs_to_schedule: list[int] = []
mm_positions = request.mm_positions
assert mm_positions is not None
assert len(mm_positions) > 0
mm_features = request.mm_features
assert mm_features is not None
assert len(mm_features) > 0
# NOTE: since scheduler operates on the request level (possibly with
# multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level.
mm_hashes_to_schedule = set()
num_tokens_to_schedule = 0
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
for i, mm_feature in enumerate(mm_features):
start_pos = mm_feature.mm_position.offset
num_encoder_tokens = mm_feature.mm_position.length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
@@ -778,7 +778,7 @@ class Scheduler(SchedulerInterface):
if not self.is_encoder_decoder:
# We are not using the encoder cache for encoder-decoder models,
# yet.
if request.mm_hashes[i] in mm_hashes_to_schedule:
if request.mm_features[i].identifier in mm_hashes_to_schedule:
# The same encoder input has already been scheduled in the
# current step.
continue
@@ -820,7 +820,7 @@ class Scheduler(SchedulerInterface):
num_tokens_to_schedule += num_encoder_tokens
encoder_compute_budget -= num_encoder_tokens
mm_hashes_to_schedule.add(request.mm_hashes[i])
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
encoder_inputs_to_schedule.append(i)
return (
@@ -1048,9 +1048,9 @@ class Scheduler(SchedulerInterface):
# Here, we use list(set) to avoid modifying the set while iterating
# over it.
for input_id in list(cached_encoder_input_ids):
mm_positions = request.mm_positions[input_id]
start_pos = mm_positions.offset
num_tokens = mm_positions.length
mm_feature = request.mm_features[input_id]
start_pos = mm_feature.mm_position.offset
num_tokens = mm_feature.mm_position.length
if self.is_encoder_decoder and request.num_computed_tokens > 0:
# With Whisper, as soon as we've generated a single token,
# we know we're done with the encoder input. Cross Attention
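Note (illustrative, not part of the diff): the scheduling decision per encoder input is still a range-overlap test, now phrased against mm_feature.mm_position: the encoder output is needed in this step iff [num_computed_tokens, num_computed_tokens + num_new_tokens) intersects [offset, offset + length). A hypothetical helper capturing that test:

    def needs_encoder_output(mm_feature, num_computed_tokens: int,
                             num_new_tokens: int) -> bool:
        start = mm_feature.mm_position.offset
        end = start + mm_feature.mm_position.length
        return (num_computed_tokens < end
                and num_computed_tokens + num_new_tokens > start)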

View File

@@ -91,11 +91,6 @@ class Request:
self.mm_features = mm_features or []
self.num_encoder_inputs = len(self.mm_features)
self.has_encoder_inputs = self.num_encoder_inputs > 0
# TODO(sfeng33): Remove these legacy fields after clearing out all
# references in scheduler and model runner
self.mm_positions = [f.mm_position for f in self.mm_features]
self.mm_kwargs = [f.data for f in self.mm_features]
self.mm_hashes = [f.identifier for f in self.mm_features]
# Read-only views
# Prevent directly appending to these lists since
@@ -180,8 +175,8 @@ class Request:
return RequestStatus.get_finished_reason(self.status)
def get_num_encoder_tokens(self, input_id: int) -> int:
assert input_id < len(self.mm_positions)
num_tokens = self.mm_positions[input_id].length
assert input_id < len(self.mm_features)
num_tokens = self.mm_features[input_id].mm_position.length
return num_tokens
def record_event(
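Note (illustrative, not part of the diff): with the legacy attributes gone, mm_features is the single source of truth on Request; the invariants below (a sketch, not actual test code) follow from the constructor and get_num_encoder_tokens shown above.

    def check_request_invariants(request) -> None:
        assert request.num_encoder_inputs == len(request.mm_features)
        assert request.has_encoder_inputs == (len(request.mm_features) > 0)
        for i, f in enumerate(request.mm_features):
            assert request.get_num_encoder_tokens(i) == f.mm_position.length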

View File

@@ -10,8 +10,7 @@ import torch
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (MultiModalKwargsItem,
MultiModalKwargsItems, PlaceholderRange)
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
@@ -31,9 +30,7 @@ class CachedRequestState:
req_id: str
prompt_token_ids: list[int]
mm_kwargs: list[MultiModalKwargsItem]
mm_positions: list[PlaceholderRange]
mm_hashes: list[str]
mm_features: list[MultiModalFeatureSpec]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
generator: Optional[torch.Generator]
@@ -60,7 +57,8 @@ class CachedRequestState:
"removed in v0.13. Please use `mm_kwargs` instead.")
def mm_inputs(self) -> list[MultiModalKwargsItems]:
return [
MultiModalKwargsItems.from_seq([item]) for item in self.mm_kwargs
MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
if f.data is not None
]
def get_token_id(self, idx: int) -> int:
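Note (illustrative, not part of the diff): the deprecated mm_inputs property now yields one MultiModalKwargsItems per feature that actually carries data, so its length can be smaller than len(mm_features). A hypothetical check:

    def expected_mm_inputs_len(req_state) -> int:
        return sum(1 for f in req_state.mm_features if f.data is not None)

    # assert len(req_state.mm_inputs) == expected_mm_inputs_len(req_state)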

View File

@@ -555,9 +555,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions,
mm_hashes=new_req_data.mm_hashes,
mm_features=new_req_data.mm_features,
sampling_params=sampling_params,
pooling_params=pooling_params,
generator=generator,
@@ -698,7 +696,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
second_per_grid_ts = []
audio_feature_lengths = []
use_audio_in_video = False
for mm_item in req_state.mm_kwargs:
for mm_feature in req_state.mm_features:
mm_item = mm_feature.data
if mm_item is None:
continue
mm_input = mm_item.get_data()
if (t := mm_input.get("image_grid_thw")) is not None:
image_grid_thw.append(t.tolist())
@@ -731,7 +732,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mm_kwargs = list[MultiModalKwargsItem]()
for req in scheduler_output.scheduled_new_reqs:
mm_kwargs.extend(req.mm_kwargs)
for feature in req.mm_features:
if feature.data is not None:
mm_kwargs.append(feature.data)
# Input all modalities at once
mm_kwargs_combined: BatchedTensorInputs = {}
@@ -1361,10 +1364,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids:
mm_hash = req_state.mm_hashes[mm_input_id]
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
mm_hashes_pos.append(
(mm_hash, req_state.mm_positions[mm_input_id]))
mm_feature = req_state.mm_features[mm_input_id]
mm_hash = mm_feature.identifier
mm_kwargs.append(mm_feature.data)
mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
return mm_kwargs, mm_hashes_pos
@@ -1426,9 +1429,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id]
num_computed_tokens = \
req_state.num_computed_tokens + shift_computed_tokens
mm_positions = req_state.mm_positions
mm_hashes = req_state.mm_hashes
for i, pos_info in enumerate(mm_positions):
for mm_feature in req_state.mm_features:
pos_info = mm_feature.mm_position
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
@@ -1451,7 +1453,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
assert start_idx < end_idx
mm_hash = mm_hashes[i]
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
f"Encoder cache miss for {mm_hash}."
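Note (illustrative, not part of the diff): in the model runner the change is mostly mechanical: loops that used to index three parallel lists with the same i now walk mm_features once and read data, identifier and mm_position off the same object, skipping features without a payload before batching. A condensed, hypothetical equivalent of the new-request batching loop above:

    def gather_new_request_kwargs(scheduler_output):
        return [f.data
                for req in scheduler_output.scheduled_new_reqs
                for f in req.mm_features
                if f.data is not None]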

View File

@@ -387,9 +387,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions,
mm_hashes=new_req_data.mm_hashes,
mm_features=new_req_data.mm_features,
sampling_params=sampling_params,
pooling_params=None,
generator=None,
@@ -822,10 +820,10 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids:
mm_hash = req_state.mm_hashes[mm_input_id]
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
mm_hashes_pos.append(
(mm_hash, req_state.mm_positions[mm_input_id]))
mm_feature = req_state.mm_features[mm_input_id]
mm_hash = mm_feature.identifier
mm_kwargs.append(mm_feature.data)
mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
@@ -883,13 +881,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_id]
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions
mm_hashes = req_state.mm_hashes
# TODO unroll loop and assume/enforce --disable_chunked_mm_input
# NOTE (NickLucche) here we diverge from logic in other runners, as
# we assume to only have whole mm items to process. Hence we avoid
# the intrinsic dynamism that `gather_mm_placeholders` introduces.
for i, pos_info in enumerate(mm_positions):
for mm_feature in req_state.mm_features:
pos_info = mm_feature.mm_position
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
@@ -904,8 +901,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# The encoder output is already processed and stored
# in the decoder's KV cache.
continue
mm_hash = mm_hashes[i]
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
f"Encoder cache miss for {mm_hash}."