[Misc] Require merge_by_field_config argument (#26214)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date:   2025-10-04 16:40:14 +08:00
Parent: 44ea85137a
Commit: 736fbf4c89 (committed by GitHub)
2 changed files with 10 additions and 45 deletions

File 1 of 2:

@@ -15,7 +15,6 @@ import numpy as np
 import numpy.typing as npt
 import torch
 from PIL import Image, UnidentifiedImageError
-from typing_extensions import deprecated

 import vllm.envs as envs
 from vllm.connections import HTTPConnection, global_http_connection
@@ -376,39 +375,12 @@ def argsort_mm_positions(
     return [(modality, idx) for modality, idx, _ in sorted_flat_items]


-# Temporary back-compatibility for plugins that define model runner
-@deprecated("`group_mm_inputs_by_modality` is superseded by "
-            "`group_mm_kwargs_by_modality` and will be removed in v0.13. "
-            "Please use `group_mm_kwargs_by_modality` instead.")
-def group_mm_inputs_by_modality(
-    mm_inputs: list[MultiModalKwargsItems]
-) -> list[list[MultiModalKwargsItems]]:
-    if not mm_inputs:
-        return []
-
-    def modality_group_func(
-            mm_input: MultiModalKwargsItems) -> Union[str, int]:
-        # If the input has multiple modalities, return an id as the unique key
-        # for the mm_input input.
-        if len(mm_input) > 1:
-            return id(mm_input)
-        elif len(mm_input) == 1:
-            return next(iter(mm_input.keys()))
-
-        raise AssertionError("This line should be unreachable.")
-
-    return [
-        list(group) for _, group in groupby(mm_inputs, key=modality_group_func)
-    ]
-
-
 def group_mm_kwargs_by_modality(
     mm_kwargs: list[MultiModalKwargsItem],
     *,
     device: torch.types.Device = None,
     pin_memory: bool = False,
-    merge_by_field_config: bool = False,
+    merge_by_field_config: Optional[bool] = None,
 ) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
     """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
     modality together into the same `MultiModalKwargs` instance.
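
With the deprecation window closed, the removed `group_mm_inputs_by_modality` has no in-tree shim left: out-of-tree model runners must call `group_mm_kwargs_by_modality` directly. A minimal migration sketch, assuming the function is importable from the multimodal utils module shown above; `mm_kwargs_items`, `collect_mm_items_for_batch`, and `run_encoder` are illustrative names, not part of this commit:

    from vllm.multimodal.utils import group_mm_kwargs_by_modality

    # Hypothetical list[MultiModalKwargsItem] gathered by the model runner.
    mm_kwargs_items = collect_mm_items_for_batch()  # illustrative helper

    for modality, num_items, grouped_kwargs in group_mm_kwargs_by_modality(
            mm_kwargs_items,
            device="cuda",
            pin_memory=True,
            merge_by_field_config=True,  # now mandatory, see the check below
    ):
        run_encoder(modality, grouped_kwargs)  # illustrative dispatch
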
@@ -421,15 +393,19 @@ def group_mm_kwargs_by_modality(
     Yields:
         A tuple `(modality, num_items, grouped_kwargs)`.
     """
+    if merge_by_field_config is None:
+        raise RuntimeError(
+            "`group_mm_kwargs_by_modality` now requires "
+            "`merge_by_field_config` arg, please update your model runner "
+            "according to https://github.com/vllm-project/vllm/pull/25676.")
+
     from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems

     for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
         items_lst = list(items)

-        # TODO: Enable `merge_by_field_config` for all models
-        # to avoid creating an extra batch dimension (except for fields
-        # that are meant to be stacked anyway).
-        # We will also need to update each model to remove `flatten_bn`.
+        # TODO: Deprecate `merge_by_field_config` once
+        # we have migrated all in-tree models
         if merge_by_field_config:
             mm_kwargs_group: BatchedTensorInputs = dict(
                 MultiModalKwargsItems.from_seq(items_lst).get_data(
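
Note that the new check raises only when `merge_by_field_config` is left at its `None` default, and, since the function yields its groups, the error presumably surfaces on first iteration rather than at call time. A hedged sketch of the failure mode, with `mm_kwargs_items` as in the earlier sketch:

    gen = group_mm_kwargs_by_modality(mm_kwargs_items)  # arg omitted
    try:
        next(iter(gen))
    except RuntimeError as e:
        print(e)  # directs the caller to vllm-project/vllm PR #25676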

File 2 of 2:

@@ -7,10 +7,9 @@ from typing import Optional, cast

 import numpy as np
 import torch
-from typing_extensions import deprecated

 from vllm.lora.request import LoRARequest
-from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
+from vllm.multimodal.inputs import MultiModalFeatureSpec
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
@@ -53,16 +52,6 @@ class CachedRequestState:
     def num_tokens(self) -> int:
         return self.num_prompt_tokens + len(self.output_token_ids)

-    # Temporary back-compatibility for plugins that define model runner
-    @property
-    @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
-                "removed in v0.13. Please use `mm_kwargs` instead.")
-    def mm_inputs(self) -> list[MultiModalKwargsItems]:
-        return [
-            MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
-            if f.data is not None
-        ]
-
     def get_token_id(self, idx: int) -> int:
         if idx < self.num_prompt_tokens:
             if self.prompt_token_ids is None:
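
For plugins that still read the removed `mm_inputs` property, the deprecation message above names `mm_kwargs` as its successor, and the property's own body shows that the underlying data lives on `mm_features`. A hedged sketch of the replacement, where `state` is an existing `CachedRequestState` and all surrounding names are assumptions:

    # Raw per-feature data, equivalent to what the removed property wrapped
    # in MultiModalKwargsItems.from_seq before returning it.
    items = [f.data for f in state.mm_features if f.data is not None]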