mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 21:45:44 +08:00
[Misc] Require merge_by_field_config argument (#26214)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
44ea85137a
commit
736fbf4c89
@ -15,7 +15,6 @@ import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import HTTPConnection, global_http_connection
|
||||
@ -376,39 +375,12 @@ def argsort_mm_positions(
|
||||
return [(modality, idx) for modality, idx, _ in sorted_flat_items]
|
||||
|
||||
|
||||
# Temporary back-compatibility for plugins that define model runner
|
||||
@deprecated("`group_mm_inputs_by_modality` is superseded by "
|
||||
"`group_mm_kwargs_by_modality` and will be removed in v0.13. "
|
||||
"Please use `group_mm_kwargs_by_modality` instead.")
|
||||
def group_mm_inputs_by_modality(
|
||||
mm_inputs: list[MultiModalKwargsItems]
|
||||
) -> list[list[MultiModalKwargsItems]]:
|
||||
if not mm_inputs:
|
||||
return []
|
||||
|
||||
def modality_group_func(
|
||||
mm_input: MultiModalKwargsItems) -> Union[str, int]:
|
||||
# If the input has multiple modalities, return an id as the unique key
|
||||
# for the mm_input input.
|
||||
if len(mm_input) > 1:
|
||||
return id(mm_input)
|
||||
|
||||
elif len(mm_input) == 1:
|
||||
return next(iter(mm_input.keys()))
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
return [
|
||||
list(group) for _, group in groupby(mm_inputs, key=modality_group_func)
|
||||
]
|
||||
|
||||
|
||||
def group_mm_kwargs_by_modality(
|
||||
mm_kwargs: list[MultiModalKwargsItem],
|
||||
*,
|
||||
device: torch.types.Device = None,
|
||||
pin_memory: bool = False,
|
||||
merge_by_field_config: bool = False,
|
||||
merge_by_field_config: Optional[bool] = None,
|
||||
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
|
||||
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
|
||||
modality together into the same `MultiModalKwargs` instance.
|
||||
@ -421,15 +393,19 @@ def group_mm_kwargs_by_modality(
|
||||
Yields:
|
||||
A tuple `(modality, num_items, grouped_kwargs)`.
|
||||
"""
|
||||
if merge_by_field_config is None:
|
||||
raise RuntimeError(
|
||||
"`group_mm_kwargs_by_modality` now requires "
|
||||
"`merge_by_field_config` arg, please update your model runner "
|
||||
"according to https://github.com/vllm-project/vllm/pull/25676.")
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems
|
||||
|
||||
for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
|
||||
items_lst = list(items)
|
||||
|
||||
# TODO: Enable `merge_by_field_config` for all models
|
||||
# to avoid creating an extra batch dimension (except for fields
|
||||
# that are meant to be stacked anyway).
|
||||
# We will also need to update each model to remove `flatten_bn`.
|
||||
# TODO: Deprecate `merge_by_field_config` once
|
||||
# we have migrated all in-tree models
|
||||
if merge_by_field_config:
|
||||
mm_kwargs_group: BatchedTensorInputs = dict(
|
||||
MultiModalKwargsItems.from_seq(items_lst).get_data(
|
||||
|
||||
@ -7,10 +7,9 @@ from typing import Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
|
||||
@ -53,16 +52,6 @@ class CachedRequestState:
|
||||
def num_tokens(self) -> int:
|
||||
return self.num_prompt_tokens + len(self.output_token_ids)
|
||||
|
||||
# Temporary back-compatibility for plugins that define model runner
|
||||
@property
|
||||
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
|
||||
"removed in v0.13. Please use `mm_kwargs` instead.")
|
||||
def mm_inputs(self) -> list[MultiModalKwargsItems]:
|
||||
return [
|
||||
MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
|
||||
if f.data is not None
|
||||
]
|
||||
|
||||
def get_token_id(self, idx: int) -> int:
|
||||
if idx < self.num_prompt_tokens:
|
||||
if self.prompt_token_ids is None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user