[V1][Bugfix] Fix data item ordering in mixed-modality inference (#12259)

Signed-off-by: Roger Wang <ywang@roblox.com>

parent c81081fece
commit b197a5ccfd
@@ -1,4 +1,5 @@
 from functools import lru_cache
+from itertools import groupby
 from pathlib import Path
 from typing import TYPE_CHECKING, Optional, TypeVar, Union
 from urllib.parse import ParseResult, urlparse
@@ -26,7 +27,7 @@ _M = TypeVar("_M")
 
 if TYPE_CHECKING:
     from .hasher import MultiModalHashDict
-    from .inputs import MultiModalPlaceholderDict
+    from .inputs import MultiModalKwargs, MultiModalPlaceholderDict
 
 
 class MediaConnector:
@@ -477,3 +478,34 @@ def merge_and_sort_multimodal_metadata(
         merged_hashes = None
 
     return sorted_modalities, merged_placeholders, merged_hashes
+
+
+def group_mm_inputs_by_modality(
+        mm_inputs: list["MultiModalKwargs"]) -> list[list["MultiModalKwargs"]]:
+    """Group consecutive MultiModalKwargs from mm_inputs with the same modality
+    together into the same list for batching purposes. For MultiModalKwargs with
+    multiple modalities, put them into their own list.
+
+    Args:
+        mm_inputs: List of MultiModalKwargs.
+
+    Returns:
+        list[list[MultiModalKwargs]]: List of lists of MultiModalKwargs, where
+        each inner list contains consecutive MultiModalKwargs with the same
+        modality, or a single MultiModalKwargs with multiple modalities.
+    """
+    if not mm_inputs:
+        return []
+
+    def modality_group_func(mm_input: "MultiModalKwargs") -> Union[str, int]:
+        # If the input has multiple modalities, return its id as a unique key
+        # so that it forms its own group.
+        if len(mm_input.modalities) > 1:
+            return id(mm_input)
+
+        # Otherwise return the modality string.
+        return list(mm_input.modalities)[0]
+
+    return [
+        list(group) for _, group in groupby(mm_inputs, key=modality_group_func)
+    ]
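For illustration only (not part of this commit): a minimal standalone sketch of the grouping behaviour added above. FakeMMInput is a hypothetical stand-in for MultiModalKwargs that only models the `modalities` attribute; the grouping logic mirrors the new group_mm_inputs_by_modality helper.

# Minimal sketch, assuming a hypothetical FakeMMInput stand-in.
from dataclasses import dataclass
from itertools import groupby
from typing import Union


@dataclass
class FakeMMInput:
    # Only the set of modalities matters for grouping.
    modalities: set[str]


def modality_group_func(mm_input: FakeMMInput) -> Union[str, int]:
    # Mixed-modality items get a unique key so each forms its own group.
    if len(mm_input.modalities) > 1:
        return id(mm_input)
    # Single-modality items share their modality string as the key.
    return list(mm_input.modalities)[0]


def group_by_modality(mm_inputs: list[FakeMMInput]) -> list[list[FakeMMInput]]:
    # groupby only merges *consecutive* equal keys, so original order is kept.
    return [list(g) for _, g in groupby(mm_inputs, key=modality_group_func)]


mm_inputs = [
    FakeMMInput({"image"}),
    FakeMMInput({"image"}),
    FakeMMInput({"audio"}),
    FakeMMInput({"image", "audio"}),
    FakeMMInput({"image"}),
]
# Yields 4 groups: [image, image], [audio], [image+audio], [image]
for group in group_by_modality(mm_inputs):
    print([sorted(item.modalities) for item in group])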
@@ -17,6 +17,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
+from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LayerBlockType, cdiv, is_pin_memory_available)
@@ -629,20 +630,35 @@ class GPUModelRunner:
             for input_id in encoder_input_ids:
                 mm_inputs.append(req_state.mm_inputs[input_id])
                 req_input_ids.append((req_id, input_id))
-        batched_mm_inputs = MultiModalKwargs.batch(mm_inputs)
-        batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
-                                                       device=self.device)
-
-        # Run the encoder.
-        # `encoder_outputs` is either of the following:
-        # 1. A tensor of shape [num_images, feature_size, hidden_size]
-        # in case when feature_size is fixed across all images.
-        # 2. A list (length: num_images) of tensors, each of shape
-        # [feature_size, hidden_size] in case when the feature size is
-        # dynamic depending on input images.
-        encoder_outputs = self.model.get_multimodal_embeddings(
-            **batched_mm_inputs)
+
+        # Batch mm inputs as much as we can: if a request in the batch has
+        # multiple modalities or a different modality than the previous one,
+        # we process it separately to preserve item order.
+        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
+        # in the same batch while still being able to benefit from batching
+        # multimodal inputs. The proper solution should be reordering the
+        # encoder outputs.
+        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
+
+        encoder_outputs = []
+        for grouped_mm_inputs in grouped_mm_inputs_list:
+            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
+            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
+                                                           device=self.device)
+
+            # Run the encoder.
+            # `curr_group_outputs` is either of the following:
+            # 1. A tensor of shape (num_items, feature_size, hidden_size)
+            # in case feature_size is fixed across all multimodal items.
+            # 2. A list or tuple (length: num_items) of tensors, each of shape
+            # (feature_size, hidden_size) in case the feature size is dynamic
+            # depending on the input multimodal items.
+            curr_group_outputs = self.model.get_multimodal_embeddings(
+                **batched_mm_inputs)
+
+            for output in curr_group_outputs:
+                encoder_outputs.append(output)
 
         # Cache the encoder outputs.
         for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
             if req_id not in self.encoder_cache:
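For illustration only (not part of this commit): a schematic sketch of why the per-group encoder loop keeps outputs aligned with req_input_ids. toy_encode and the string items are hypothetical stand-ins for get_multimodal_embeddings and the real MultiModalKwargs objects.

# Schematic sketch, assuming toy stand-ins for the model runner's real types.
from itertools import groupby

mm_inputs = ["image_0", "image_1", "audio_0", "image_2"]  # original item order
req_input_ids = [("req0", 0), ("req0", 1), ("req1", 0), ("req2", 0)]


def toy_encode(batch: list[str]) -> list[str]:
    # Stand-in for self.model.get_multimodal_embeddings(**batched_mm_inputs).
    return [f"emb({item})" for item in batch]


encoder_outputs = []
for _, group in groupby(mm_inputs, key=lambda item: item.split("_")[0]):
    # Each encoder call only sees consecutive items of a single modality,
    # so appending its outputs preserves the original item order.
    encoder_outputs.extend(toy_encode(list(group)))

# Outputs line up one-to-one with the original (req_id, input_id) pairs.
for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
    print(req_id, input_id, output)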