[Chore] Simplify logic of _execute_mm_encoder (#31222)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung (2025-12-24 10:15:16 +08:00), committed by GitHub
parent bc0a5a0c08
commit ca6a95ba25


@@ -61,6 +61,7 @@ from vllm.model_executor.layers.rotary_embedding import (
 )
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.model_executor.models.interfaces import (
+    MultiModalEmbeddings,
     SupportsMRoPE,
     SupportsMultiModal,
     SupportsXDRoPE,
@@ -78,11 +79,7 @@ from vllm.model_executor.models.interfaces_base import (
     is_text_generation_model,
 )
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (
-    BatchedTensorInputs,
-    MultiModalKwargsItem,
-    PlaceholderRange,
-)
+from vllm.multimodal.inputs import BatchedTensorInputs, MultiModalKwargsItem
 from vllm.multimodal.utils import group_mm_kwargs_by_modality
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
@@ -2097,28 +2094,27 @@ class GPUModelRunner(
         ]
         return logits_indices_padded

-    def _batch_mm_kwargs_from_scheduler(
+    def _batch_mm_inputs_from_scheduler(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[list[MultiModalKwargsItem], list[tuple[str, PlaceholderRange]]]:
-        """Batch multimodal kwargs from scheduled encoder inputs.
+    ) -> tuple[list[str], list[MultiModalKwargsItem]]:
+        """Batch multimodal inputs from scheduled encoder inputs.

         Args:
             scheduler_output: The scheduler output containing scheduled encoder
                 inputs.

         Returns:
-            A tuple of (mm_kwargs, req_ids_pos) where:
-            - mm_kwargs: List of multimodal kwargs items to be batched
-            - mm_hashes_pos: List of (mm_hash, position_info) tuples
+            A tuple of (mm_hashes, mm_kwargs) where:
+            - mm_hashes: List of multimodal hashes for each item
+            - mm_kwargs: List of multimodal kwargs for each item
         """
         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
         if not scheduled_encoder_inputs:
             return [], []

         # Batch the multi-modal inputs.
+        mm_hashes = list[str]()
         mm_kwargs = list[MultiModalKwargsItem]()
-        # list of tuple (mm_hash, position_info)
-        mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
             req_state = self.requests[req_id]
@@ -2126,19 +2122,16 @@ class GPUModelRunner(
                 mm_feature = req_state.mm_features[mm_input_id]
                 if mm_feature.data is None:
                     continue
-                mm_hash = mm_feature.identifier
-                mm_kwargs.append(mm_feature.data)
-                mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
-        return mm_kwargs, mm_hashes_pos
+                mm_hashes.append(mm_feature.identifier)
+                mm_kwargs.append(mm_feature.data)
+        return mm_hashes, mm_kwargs

     def _execute_mm_encoder(
         self, scheduler_output: "SchedulerOutput"
     ) -> list[torch.Tensor]:
-        # Batch the multi-modal inputs using the helper method.
-        mm_kwargs, mm_hashes_pos = self._batch_mm_kwargs_from_scheduler(
-            scheduler_output
-        )
+        mm_hashes, mm_kwargs = self._batch_mm_inputs_from_scheduler(scheduler_output)

         if not mm_kwargs:
             return []
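
Note on the hunk above: the helper now returns two parallel lists (hashes and kwargs) instead of a kwargs list plus (mm_hash, position) tuples, since the caller only needs the hash to key the encoder cache. Below is a minimal standalone sketch of that pairing pattern; the Feature dataclass and batch_inputs function are hypothetical stand-ins, not vLLM APIs.

from dataclasses import dataclass

@dataclass
class Feature:
    identifier: str        # the mm_hash used as a cache key
    data: dict | None      # encoder kwargs for this item, or None if nothing to encode

def batch_inputs(features: list[Feature]) -> tuple[list[str], list[dict]]:
    """Return parallel lists of hashes and kwargs, skipping items with no data."""
    hashes: list[str] = []
    kwargs: list[dict] = []
    for feature in features:
        if feature.data is None:
            continue
        hashes.append(feature.identifier)
        kwargs.append(feature.data)
    return hashes, kwargs

hashes, kwargs = batch_inputs(
    [Feature("img-abc", {"pixel_values": [1, 2, 3]}), Feature("img-def", None)]
)
assert hashes == ["img-abc"] and len(kwargs) == 1
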
@@ -2157,7 +2150,7 @@ class GPUModelRunner(
             device=self.device,
             pin_memory=self.pin_memory,
         ):
-            curr_group_outputs: list[torch.Tensor] = []
+            curr_group_outputs: MultiModalEmbeddings

             # EVS-related change.
             # (ekhvedchenia): Temporary hack to limit peak memory usage when
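
For context on the truncated EVS comment above: when more than one video is scheduled for a single encoder call, the branch changed below encodes the items one at a time and now collects the results into a temporary list before assigning the union-typed curr_group_outputs, which appears intended to cap peak encoder memory. A rough, self-contained illustration of that accumulate-then-assign pattern; encode_one and encode_group are made-up stand-ins for the real model.embed_multimodal path.

import torch

def encode_one(item: torch.Tensor) -> torch.Tensor:
    # Stand-in for a per-item encoder forward pass.
    return item * 2.0

def encode_group(items: list[torch.Tensor], micro_batch: bool) -> list[torch.Tensor]:
    if micro_batch and len(items) > 1:
        # Encode one item at a time so only one item's activations are live at once.
        outputs_lst: list[torch.Tensor] = []
        for item in items:
            outputs_lst.append(encode_one(item))
        outputs = outputs_lst
    else:
        # Encode the whole group in a single call.
        outputs = [encode_one(item) for item in items]
    return outputs

assert len(encode_group([torch.ones(2), torch.ones(3)], micro_batch=True)) == 2
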
@@ -2173,6 +2166,7 @@ class GPUModelRunner(
                 and modality == "video"
                 and num_items > 1
             ):
+                curr_group_outputs_lst = list[torch.Tensor]()
                 for video_mm_kwargs_item in filter(
                     lambda item: item.modality == "video", mm_kwargs
                 ):
@@ -2188,7 +2182,9 @@ class GPUModelRunner(
                         **micro_batch_mm_inputs
                     )

-                    curr_group_outputs.extend(micro_batch_outputs)
+                    curr_group_outputs_lst.extend(micro_batch_outputs)
+
+                curr_group_outputs = curr_group_outputs_lst
             else:
                 # Run the encoder.
                 # `curr_group_outputs` is either of the following:
@@ -2197,7 +2193,7 @@ class GPUModelRunner(
                 # 2. A list or tuple (length: num_items) of tensors,
                 # each of shape (feature_size, hidden_size) in case the feature
                 # size is dynamic depending on the input multimodal items.
-                curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)  # type: ignore[assignment]
+                curr_group_outputs = model.embed_multimodal(**mm_kwargs_group)

             sanity_check_mm_encoder_outputs(
                 curr_group_outputs,
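
The annotation change earlier in this diff (curr_group_outputs: MultiModalEmbeddings instead of list[torch.Tensor]) is what lets the type: ignore[assignment] above be dropped: per the comment in this hunk, the encoder may return either one batched tensor or a sequence of per-item tensors, and the variable's type now admits both. A sketch of that idea with an approximated alias; Embeddings and fake_encoder are illustrative, not vLLM's actual definitions.

from typing import Union

import torch

# Rough approximation of a MultiModalEmbeddings-style alias: one batched tensor,
# or a list/tuple of per-item tensors whose feature sizes may differ.
Embeddings = Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]

def fake_encoder(dynamic: bool) -> Embeddings:
    if dynamic:
        # Per-item tensors of shape (feature_size, hidden_size).
        return [torch.zeros(3, 8), torch.zeros(5, 8)]
    # One batched tensor of shape (num_items, feature_size, hidden_size).
    return torch.zeros(2, 4, 8)

outputs: Embeddings  # declared without an initializer, as in the diff
outputs = fake_encoder(dynamic=True)  # either return form type-checks without an ignore
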
@@ -2206,7 +2202,7 @@ class GPUModelRunner(
             encoder_outputs.extend(curr_group_outputs)

         # Cache the encoder outputs by mm_hash
-        for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
+        for mm_hash, output in zip(mm_hashes, encoder_outputs):
             self.encoder_cache[mm_hash] = output
             logger.debug("Finish execute for mm hash %s", mm_hash)
             self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
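
Since _batch_mm_inputs_from_scheduler appends to mm_hashes and mm_kwargs in lockstep, and the encoder loop extends encoder_outputs item by item, the final zip pairs each hash with the embedding computed for that item, the same ordering assumption the removed mm_hashes_pos version relied on. A toy sketch of the hash-keyed caching, with a plain dict standing in for self.encoder_cache:

import torch

encoder_cache: dict[str, torch.Tensor] = {}

mm_hashes = ["img-abc", "vid-123"]
encoder_outputs = [torch.zeros(4, 8), torch.zeros(16, 8)]

# The i-th hash corresponds to the i-th output, so zip yields the cache key pairing.
for mm_hash, output in zip(mm_hashes, encoder_outputs):
    encoder_cache[mm_hash] = output

assert set(encoder_cache) == {"img-abc", "vid-123"}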