[Misc] Require merge_by_field_config argument (#26214)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 44ea85137a
commit 736fbf4c89
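For out-of-tree model runners, the practical effect is that `group_mm_kwargs_by_modality` no longer defaults `merge_by_field_config` to `False`; every call site must now pass the keyword explicitly. A minimal sketch of an updated call site, assuming the helper is imported from `vllm.multimodal.utils` (file paths are not visible in this diff) and that `mm_kwargs`, `device`, and `pin_memory` already exist in the runner:

```python
# Sketch of a post-#26214 call site. The import path and surrounding
# variables are assumptions; only the keyword-only parameters come from
# this diff.
from vllm.multimodal.utils import group_mm_kwargs_by_modality


def batch_mm_kwargs_per_modality(mm_kwargs, device, pin_memory):
    # `merge_by_field_config` must now be forwarded explicitly; pick the
    # value your model expects instead of relying on the old False default.
    yield from group_mm_kwargs_by_modality(
        mm_kwargs,
        device=device,
        pin_memory=pin_memory,
        merge_by_field_config=True,
    )
```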
@@ -15,7 +15,6 @@ import numpy as np
 import numpy.typing as npt
 import torch
 from PIL import Image, UnidentifiedImageError
-from typing_extensions import deprecated
 
 import vllm.envs as envs
 from vllm.connections import HTTPConnection, global_http_connection
@@ -376,39 +375,12 @@ def argsort_mm_positions(
     return [(modality, idx) for modality, idx, _ in sorted_flat_items]
 
 
-# Temporary back-compatibility for plugins that define model runner
-@deprecated("`group_mm_inputs_by_modality` is superseded by "
-            "`group_mm_kwargs_by_modality` and will be removed in v0.13. "
-            "Please use `group_mm_kwargs_by_modality` instead.")
-def group_mm_inputs_by_modality(
-    mm_inputs: list[MultiModalKwargsItems]
-) -> list[list[MultiModalKwargsItems]]:
-    if not mm_inputs:
-        return []
-
-    def modality_group_func(
-            mm_input: MultiModalKwargsItems) -> Union[str, int]:
-        # If the input has multiple modalities, return an id as the unique key
-        # for the mm_input input.
-        if len(mm_input) > 1:
-            return id(mm_input)
-
-        elif len(mm_input) == 1:
-            return next(iter(mm_input.keys()))
-
-        raise AssertionError("This line should be unreachable.")
-
-    return [
-        list(group) for _, group in groupby(mm_inputs, key=modality_group_func)
-    ]
-
-
 def group_mm_kwargs_by_modality(
     mm_kwargs: list[MultiModalKwargsItem],
     *,
     device: torch.types.Device = None,
     pin_memory: bool = False,
-    merge_by_field_config: bool = False,
+    merge_by_field_config: Optional[bool] = None,
 ) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
     """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
     modality together into the same `MultiModalKwargs` instance.
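This hunk also deletes the deprecation shim outright rather than waiting for v0.13, so plugins still calling `group_mm_inputs_by_modality` break immediately. Plugins that genuinely need the old nested-list view can inline the removed grouping themselves, since it was a thin wrapper around `itertools.groupby`. A minimal standalone equivalent, reconstructed from the deleted code above (the function name here is a local stand-in, not a vLLM API):

```python
from itertools import groupby
from typing import Union

# Import path confirmed elsewhere in this diff.
from vllm.multimodal.inputs import MultiModalKwargsItems


def group_inputs_by_modality(
    mm_inputs: list[MultiModalKwargsItems],
) -> list[list[MultiModalKwargsItems]]:
    """Group consecutive items by modality, as the removed helper did."""
    if not mm_inputs:
        return []

    def modality_group_func(mm_input: MultiModalKwargsItems) -> Union[str, int]:
        # Items spanning multiple modalities get a unique key (their id) so
        # they always form their own group; single-modality items group by
        # their modality name.
        if len(mm_input) > 1:
            return id(mm_input)
        return next(iter(mm_input.keys()))

    return [
        list(group) for _, group in groupby(mm_inputs, key=modality_group_func)
    ]
```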
@@ -421,15 +393,19 @@ def group_mm_kwargs_by_modality(
     Yields:
         A tuple `(modality, num_items, grouped_kwargs)`.
     """
+    if merge_by_field_config is None:
+        raise RuntimeError(
+            "`group_mm_kwargs_by_modality` now requires "
+            "`merge_by_field_config` arg, please update your model runner "
+            "according to https://github.com/vllm-project/vllm/pull/25676.")
+
     from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems
 
     for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
         items_lst = list(items)
 
-        # TODO: Enable `merge_by_field_config` for all models
-        # to avoid creating an extra batch dimension (except for fields
-        # that are meant to be stacked anyway).
-        # We will also need to update each model to remove `flatten_bn`.
+        # TODO: Deprecate `merge_by_field_config` once
+        # we have migrated all in-tree models
         if merge_by_field_config:
             mm_kwargs_group: BatchedTensorInputs = dict(
                 MultiModalKwargsItems.from_seq(items_lst).get_data(
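Note that `group_mm_kwargs_by_modality` is a generator (its docstring says "Yields"), so the new guard fires on the first iteration rather than at the call itself. A hedged sketch of what an un-migrated runner now sees, with `mm_kwargs` as a placeholder and the import path assumed:

```python
from vllm.multimodal.utils import group_mm_kwargs_by_modality  # assumed path


def probe_legacy_call(mm_kwargs):
    # Omitting merge_by_field_config leaves it at None; the guard added in
    # this hunk raises once the generator is first advanced.
    groups = group_mm_kwargs_by_modality(mm_kwargs)  # no error raised yet
    try:
        next(iter(groups))
    except RuntimeError as err:
        # Message (from this diff): "`group_mm_kwargs_by_modality` now
        # requires `merge_by_field_config` arg, please update your model
        # runner according to
        # https://github.com/vllm-project/vllm/pull/25676."
        return str(err)
    return None
```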
@@ -7,10 +7,9 @@ from typing import Optional, cast
 
 import numpy as np
 import torch
-from typing_extensions import deprecated
 
 from vllm.lora.request import LoRARequest
-from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
+from vllm.multimodal.inputs import MultiModalFeatureSpec
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
@@ -53,16 +52,6 @@ class CachedRequestState:
     def num_tokens(self) -> int:
         return self.num_prompt_tokens + len(self.output_token_ids)
 
-    # Temporary back-compatibility for plugins that define model runner
-    @property
-    @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
-                "removed in v0.13. Please use `mm_kwargs` instead.")
-    def mm_inputs(self) -> list[MultiModalKwargsItems]:
-        return [
-            MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
-            if f.data is not None
-        ]
-
     def get_token_id(self, idx: int) -> int:
         if idx < self.num_prompt_tokens:
             if self.prompt_token_ids is None:
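With the `mm_inputs` back-compat property removed from `CachedRequestState`, plugin code that still needs the old list-of-`MultiModalKwargsItems` view has to build it from `mm_features` itself, exactly as the removed shim did (the deprecation notice pointed at `mm_kwargs` as the supported accessor). A minimal sketch, where `req_state` stands in for an existing `CachedRequestState`:

```python
from vllm.multimodal.inputs import MultiModalKwargsItems


def legacy_mm_inputs(req_state) -> list[MultiModalKwargsItems]:
    # Rebuild the removed CachedRequestState.mm_inputs view from mm_features,
    # mirroring the deleted property body line for line.
    return [
        MultiModalKwargsItems.from_seq([f.data])
        for f in req_state.mm_features
        if f.data is not None
    ]
```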