[V1] Extend beyond image modality and support mixed-modality inference with Llava-OneVision (#11685)
Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>

parent e20c92bb61
commit 91b361ae89
@@ -647,7 +647,7 @@ See [this page](#generative-models) for more information on how to use generativ
  - `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc.
  -
  - ✅︎
  -
  - ✅︎
* - `MiniCPMV`
  - MiniCPM-V
  - T + I<sup>E+</sup>
@@ -2,16 +2,22 @@ import base64
import mimetypes
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Dict, Tuple
from typing import TYPE_CHECKING, Dict, NamedTuple, Optional, Tuple

import numpy as np
import pytest
from PIL import Image, ImageChops
from transformers import AutoConfig, AutoTokenizer

from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import (MediaConnector,
                                   merge_and_sort_multimodal_metadata,
                                   repeat_and_pad_placeholder_tokens)

if TYPE_CHECKING:
    from vllm.multimodal.hasher import MultiModalHashDict
    from vllm.multimodal.inputs import MultiModalPlaceholderDict

# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
@@ -191,3 +197,204 @@ def test_repeat_and_pad_placeholder_tokens(model):
    assert new_prompt == expected_prompt
    assert new_token_ids == expected_token_ids
    assert ranges == expected_ranges


# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
class TestCase(NamedTuple):
    mm_positions: "MultiModalPlaceholderDict"
    mm_hashes: Optional["MultiModalHashDict"]
    expected_modalities: list[str]
    expected_ranges: list[PlaceholderRange]
    expected_hashes: Optional[list[str]]


def test_merge_and_sort_multimodal_metadata():

    test_cases = [
        # Single modality should return result as is but flattened
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=3, length=2),
                ]
            },
            mm_hashes={"image": ["hash1", "hash2"]},
            expected_modalities=["image"],
            expected_ranges=[
                PlaceholderRange(offset=0, length=2),
                PlaceholderRange(offset=3, length=2),
            ],
            expected_hashes=["hash1", "hash2"],
        ),

        # Single modality without hashes return None for mm hash.
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=2),
                ]
            },
            mm_hashes=None,
            expected_modalities=["image"],
            expected_ranges=[
                PlaceholderRange(offset=0, length=2),
                PlaceholderRange(offset=2, length=2),
            ],
            expected_hashes=None,
        ),

        # Multiple modalities with hashes should return sorted modalities
        # and flattened ranges and hashes.
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=7, length=4),
                    PlaceholderRange(offset=11, length=5),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                ]
            },
            mm_hashes={
                "image": ["image_hash1", "image_hash2"],
                "audio": ["audio_hash1", "audio_hash2"],
            },
            expected_modalities=["audio", "image"],
            expected_ranges=[
                PlaceholderRange(offset=0, length=2),
                PlaceholderRange(offset=2, length=3),
                PlaceholderRange(offset=7, length=4),
                PlaceholderRange(offset=11, length=5),
            ],
            expected_hashes=[
                "audio_hash1", "audio_hash2", "image_hash1", "image_hash2"
            ],
        ),

        # Multiple modalities without hashes should return sorted modalities
        # and flattened ranges and None.
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=7, length=4),
                    PlaceholderRange(offset=11, length=5),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                ]
            },
            mm_hashes=None,
            expected_modalities=["audio", "image"],
            expected_ranges=[
                PlaceholderRange(offset=0, length=2),
                PlaceholderRange(offset=2, length=3),
                PlaceholderRange(offset=7, length=4),
                PlaceholderRange(offset=11, length=5),
            ],
            expected_hashes=None,
        ),

        # Three modalities
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=15, length=7),
                    PlaceholderRange(offset=22, length=8),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=3, length=4),
                    PlaceholderRange(offset=7, length=5),
                    PlaceholderRange(offset=12, length=6),
                ]
            },
            mm_hashes={
                "image": ["image_hash1", "image_hash2"],
                "audio": ["audio_hash1"],
                "video": ["video_hash1", "video_hash2", "video_hash3"]
            },
            expected_modalities=["audio", "video", "image"],
            expected_ranges=[
                PlaceholderRange(offset=0, length=2),
                PlaceholderRange(offset=3, length=4),
                PlaceholderRange(offset=7, length=5),
                PlaceholderRange(offset=12, length=6),
                PlaceholderRange(offset=15, length=7),
                PlaceholderRange(offset=22, length=8),
            ],
            expected_hashes=[
                "audio_hash1", "video_hash1", "video_hash2", "video_hash3",
                "image_hash1", "image_hash2"
            ],
        ),
    ]

    for (mm_positions, mm_hashes, expected_modalities, expected_ranges,
         expected_hashes) in test_cases:
        modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
            mm_positions, mm_hashes)

        assert modalities == expected_modalities
        assert ranges == expected_ranges
        assert hashes == expected_hashes


def test_merge_and_sort_multimodal_metadata_with_interleaving():

    test_cases = [

        # <image> <audio> <image> <audio>
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=4),
                    PlaceholderRange(offset=8, length=2),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                    PlaceholderRange(offset=11, length=4),
                ]
            },
            mm_hashes={
                "image": ["image_hash1", "image_hash2"],
                "audio": ["audio_hash1", "audio_hash2"],
            },
            expected_modalities=[],
            expected_ranges=[],
            expected_hashes=None,
        ),

        # <image> <image> <video> <audio> <image>
        TestCase(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                    PlaceholderRange(offset=20, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ]
            },
            mm_hashes=None,
            expected_modalities=[],
            expected_ranges=[],
            expected_hashes=None,
        ),
    ]

    for case in test_cases:
        with pytest.raises(ValueError) as ex_info:
            merge_and_sort_multimodal_metadata(case.mm_positions,
                                               case.mm_hashes)

        assert "Interleaved mixed-modality" in str(ex_info.value)

@@ -1,6 +1,6 @@
import pytest

from vllm.inputs import token_inputs
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
                                         KVCacheBlock,
@@ -14,14 +14,18 @@ def make_request(request_id,
                 prompt_token_ids,
                 mm_positions=None,
                 mm_hashes=None):
    if mm_positions is None:
        multi_modal_inputs = None
    else:
        multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)

    return Request(
        request_id=request_id,
        inputs=token_inputs(
            prompt_token_ids=prompt_token_ids,
            multi_modal_placeholders={"image": mm_positions}
            if mm_positions else None,
            multi_modal_hashes=mm_hashes,
        ),
        prompt=None,
        prompt_token_ids=prompt_token_ids,
        multi_modal_inputs=multi_modal_inputs,
        multi_modal_hashes=mm_hashes,
        multi_modal_placeholders=mm_positions,
        sampling_params=SamplingParams(max_tokens=17),
        eos_token_id=100,
        arrival_time=0,

@@ -1,8 +1,7 @@
"""Compare the with and without prefix caching."""
import pytest

from vllm.inputs import token_inputs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
@@ -13,12 +12,18 @@ def make_request(request_id,
                 prompt_token_ids,
                 mm_positions=None,
                 mm_hashes=None):
    if mm_positions is None:
        multi_modal_inputs = None
    else:
        multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)

    return Request(
        request_id=request_id,
        inputs=token_inputs(prompt_token_ids=prompt_token_ids,
                            multi_modal_placeholders={"image": mm_positions}
                            if mm_positions else None,
                            multi_modal_hashes=mm_hashes),
        prompt=None,
        prompt_token_ids=prompt_token_ids,
        multi_modal_inputs=multi_modal_inputs,
        multi_modal_hashes=mm_hashes,
        multi_modal_placeholders=mm_positions,
        sampling_params=SamplingParams(max_tokens=17),
        eos_token_id=100,
        arrival_time=0,

@@ -39,8 +39,12 @@ class SupportsMultiModal(Protocol):

        The output embeddings must be one of the following formats:
        - A list or tuple of 2D tensors, where each tensor corresponds to
          each input image.
          each input multimodal data item (e.g., image).
        - A single 3D tensor, with the batch dimension grouping the 2D tensors.

        NOTE: The returned multimodal embeddings must be in the same order as
        the appearances of their corresponding multimodal data item in the
        input prompt.
        """
        ...


@@ -35,6 +35,9 @@ from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                    maybe_prefix, merge_multimodal_embeddings)

# For profile run
_MAX_FRAMES_PER_VIDEO = 16


class LlavaOnevisionVideoPixelInputs(TypedDict):
    type: Literal["pixel_values_videos"]
@@ -223,8 +226,10 @@ class LlavaOnevisionProfilingInfo(LlavaOnevisionProcessingMixin,
        max_image_tokens = self._get_max_image_tokens() * max_images
        max_total_frames = self._get_max_video_frames(seq_len -
                                                      max_image_tokens)
        max_frames_per_video = min(max_total_frames // max(max_videos, 1),
                                   _MAX_FRAMES_PER_VIDEO)

        return max(max_total_frames // max(max_videos, 1), 1)
        return max(max_frames_per_video, 1)

    def _get_max_video_tokens(self, seq_len: int) -> int:
        target_width, target_height = self._get_image_size_with_most_features()
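A quick numeric check of the new per-video frame cap introduced above (the helper name and the numbers are illustrative, not part of the commit):

_MAX_FRAMES_PER_VIDEO = 16

def _max_frames_per_video(max_total_frames: int, max_videos: int) -> int:
    # New behavior: the per-video frame count is additionally capped at 16.
    per_video = min(max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO)
    return max(per_video, 1)

assert _max_frames_per_video(100, 2) == 16  # the old formula would have returned 50
assert _max_frames_per_video(10, 4) == 2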
@@ -558,13 +563,15 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        modalities = {}

        if "pixel_values" in kwargs:
            modalities["images"] = self._parse_and_validate_image_input(
                **kwargs)

        if "pixel_values_videos" in kwargs:
            modalities["videos"] = self._parse_and_validate_video_input(
                **kwargs)
        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if input_key == "pixel_values" and "images" not in modalities:
                modalities["images"] = self._parse_and_validate_image_input(
                    **kwargs)
            if input_key == "pixel_values_videos" and "videos" not in modalities:  # noqa E501
                modalities["videos"] = self._parse_and_validate_video_input(
                    **kwargs)

        return modalities

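The rewritten loop above relies on Python preserving keyword-argument insertion order (PEP 468), so the parsed modalities come out in the order their inputs appeared in the request. A minimal, self-contained illustration (the helper and the placeholder return values are hypothetical, not vLLM API):

def ordered_modalities(**kwargs):
    modalities = {}
    for input_key in kwargs:  # kwargs iterate in caller order
        if input_key == "pixel_values" and "images" not in modalities:
            modalities["images"] = "parsed-image-input"
        if input_key == "pixel_values_videos" and "videos" not in modalities:
            modalities["videos"] = "parsed-video-input"
    return list(modalities)

assert ordered_modalities(pixel_values_videos=0, pixel_values=1) == ["videos", "images"]
assert ordered_modalities(pixel_values=1, pixel_values_videos=0) == ["images", "videos"]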
@@ -824,21 +831,21 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
        if not modalities:
            return None

        # We make a tuple of each embedding with its modality string. This is a
        # temporary workaround for models to handle mixed modalities when
        # get_multimodal_embeddings and get_input_embeddings are called
        # separately.
        # TODO(ywang96): Add support for mixed-modality inference for v1.
        multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
        # The result multimodal_embeddings is a tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        if "images" in modalities:
            image_input = modalities["images"]
            vision_embeddings = self._process_image_input(image_input)
            multimodal_embeddings.append((vision_embeddings, "image"))
        if "videos" in modalities:
            video_input = modalities["videos"]
            video_embeddings = self._process_video_pixels(video_input)
            multimodal_embeddings.append((video_embeddings, "video"))
        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in modalities:
            if modality == "images":
                image_input = modalities["images"]
                vision_embeddings = self._process_image_input(image_input)
                multimodal_embeddings += tuple(vision_embeddings)
            if modality == "videos":
                video_input = modalities["videos"]
                video_embeddings = self._process_video_pixels(video_input)
                multimodal_embeddings += tuple(video_embeddings)

        return multimodal_embeddings

@@ -850,15 +857,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
    ) -> torch.Tensor:
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            for embeddings, modality in multimodal_embeddings:
                if modality == "image":
                    inputs_embeds = merge_multimodal_embeddings(
                        input_ids, inputs_embeds, embeddings,
                        self.config.image_token_index)
                if modality == "video":
                    inputs_embeds = merge_multimodal_embeddings(
                        input_ids, inputs_embeds, embeddings,
                        self.config.video_token_index)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                [self.config.image_token_index, self.config.video_token_index])
        return inputs_embeds

    def forward(

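Conceptually, the single merge call above scatters the flattened multimodal embeddings into every position whose token id matches one of the listed placeholder ids, in prompt order. A hedged, self-contained sketch of that behavior (not the actual merge_multimodal_embeddings implementation):

import torch

def merge_by_placeholder_ids(input_ids: torch.Tensor,
                             inputs_embeds: torch.Tensor,
                             mm_embeds: torch.Tensor,
                             placeholder_ids: list[int]) -> torch.Tensor:
    # Positions holding any placeholder token receive the multimodal
    # embeddings, in the order the placeholders appear in the prompt.
    mask = torch.isin(input_ids,
                      torch.tensor(placeholder_ids, device=input_ids.device))
    assert int(mask.sum()) == mm_embeds.shape[0]
    out = inputs_embeds.clone()
    out[mask] = mm_embeds.to(out.dtype)
    return out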
@@ -972,8 +972,6 @@ def image_input_mapper_for_molmo(
    assert len(data) == 1, "Molmo supports only one image per prompt."
    data = data[0]

    # Remove unused dummy PIL image
    data.pop('raw_mm_data', None)
    return MultiModalKwargs(data)


@@ -1019,7 +1017,6 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
    dummy_imgdata = {
        "images": out["images"],
        "image_input_idx": out["image_input_idx"],
        "raw_mm_data": dummy_image,
    }
    if "image_masks" in out:
        dummy_imgdata["image_masks"] = out["image_masks"]

@@ -1,4 +1,5 @@
from .base import MultiModalPlaceholderMap, MultiModalPlugin
from .hasher import MultiModalHashDict, MultiModalHasher
from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
                     MultiModalDataDict, MultiModalKwargs,
                     MultiModalPlaceholderDict, NestedTensors)
@@ -18,6 +19,8 @@ __all__ = [
    "ModalityData",
    "MultiModalDataBuiltins",
    "MultiModalDataDict",
    "MultiModalHashDict",
    "MultiModalHasher",
    "MultiModalKwargs",
    "MultiModalPlaceholderDict",
    "MultiModalPlaceholderMap",

vllm/multimodal/hasher.py (new file, 100 lines)
@@ -0,0 +1,100 @@
import pickle
from typing import TYPE_CHECKING, Iterable, Mapping, Optional

import numpy as np
import torch
from blake3 import blake3
from PIL import Image

from vllm.logger import init_logger

if TYPE_CHECKING:
    from vllm.inputs import TokensPrompt

logger = init_logger(__name__)

MultiModalHashDict = Mapping[str, list[str]]
"""
A dictionary containing hashes for items in each modality.
"""


class MultiModalHasher:

    @classmethod
    def serialize_item(cls, obj: object) -> bytes:
        # Simple cases
        if isinstance(obj, str):
            return obj.encode("utf-8")
        if isinstance(obj, bytes):
            return obj
        if isinstance(obj, Image.Image):
            return obj.tobytes()

        # Convertible to NumPy arrays
        if isinstance(obj, torch.Tensor):
            obj = obj.numpy()
        if isinstance(obj, (int, float)):
            obj = np.array(obj)
        if isinstance(obj, np.ndarray):
            return obj.tobytes()

        logger.warning(
            "No serialization method found for %s. "
            "Falling back to pickle.", type(obj))

        return pickle.dumps(obj)

    @classmethod
    def item_to_bytes(
        cls,
        key: str,
        obj: object,
    ) -> Iterable[tuple[bytes, bytes]]:
        # Recursive cases
        if isinstance(obj, (list, tuple)):
            for i, elem in enumerate(obj):
                yield from cls.item_to_bytes(f"{key}.{i}", elem)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                yield from cls.item_to_bytes(f"{key}.{k}", v)
        else:
            key_bytes = cls.serialize_item(key)
            value_bytes = cls.serialize_item(obj)
            yield key_bytes, value_bytes

    @classmethod
    def hash_kwargs(cls, **kwargs: object) -> str:
        hasher = blake3()

        for k, v in kwargs.items():
            for k_bytes, v_bytes in cls.item_to_bytes(k, v):
                hasher.update(k_bytes)
                hasher.update(v_bytes)

        return hasher.hexdigest()

    @classmethod
    def hash_prompt_mm_data(
            cls, prompt: "TokensPrompt") -> Optional["MultiModalHashDict"]:
        """Hash multimodal data in the user input prompt if they exist."""

        if "multi_modal_data" not in prompt:
            return None

        mm_data = prompt["multi_modal_data"]
        if not mm_data:
            # mm_data can be None or an empty dict.
            return None

        mm_items = {
            modality: items if isinstance(items, list) else [items]
            for modality, items in mm_data.items()
        }

        mm_hashes = {
            modality: [cls.hash_kwargs(**{modality: item}) for item in items]
            for modality, items in mm_items.items()
        }

        return mm_hashes

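A small usage sketch for the new hasher (the model id, image, and prompt dict are placeholders; only the hash_kwargs and hash_prompt_mm_data methods shown above are used):

from PIL import Image

from vllm.multimodal.hasher import MultiModalHasher

img = Image.new("RGB", (64, 64), color="red")

# Identical content yields identical hashes, which is what the processing
# cache and V1 prefix caching key on.
h1 = MultiModalHasher.hash_kwargs(model_id="some/model", image=img)
h2 = MultiModalHasher.hash_kwargs(model_id="some/model",
                                  image=Image.new("RGB", (64, 64), color="red"))
assert h1 == h2

# Per-modality hashes for a prompt's multimodal data.
hashes = MultiModalHasher.hash_prompt_mm_data({
    "prompt_token_ids": [1, 2, 3],
    "multi_modal_data": {"image": [img, img]},
})
assert hashes is not None and len(hashes["image"]) == 2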
@@ -2,8 +2,8 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import (Any, Literal, Optional, TypedDict, TypeVar, Union, cast,
                    final)
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
                    Union, cast, final)

import numpy as np
import torch
@@ -14,6 +14,9 @@ from typing_extensions import NotRequired, TypeAlias

from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves

if TYPE_CHECKING:
    from .hasher import MultiModalHashDict

_T = TypeVar("_T")

HfImageItem: TypeAlias = Union[Image, np.ndarray, torch.Tensor]
@@ -513,7 +516,7 @@ class MultiModalInputsV2(TypedDict):
    mm_kwargs: MultiModalKwargs
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: NotRequired[list[str]]
    mm_hashes: NotRequired[Optional["MultiModalHashDict"]]
    """The hashes of the multi-modal data."""

    mm_placeholders: MultiModalPlaceholderDict

@@ -1,4 +1,3 @@
import pickle
import re
from abc import ABC, abstractmethod
from collections import defaultdict
@@ -7,18 +6,16 @@ from dataclasses import dataclass, field
from functools import lru_cache
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union

import numpy as np
import torch
from blake3 import blake3
from PIL import Image
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin

from vllm import envs
from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                               encode_tokens)
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby

from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
                     MultiModalInputsV2, MultiModalKwargs,
                     MultiModalKwargsItem, PlaceholderRange)
@@ -486,56 +483,6 @@ class ProcessingCache:
        logger.debug("ProcessingCache: hit_ratio = %.2f",
                     cache_stats.hit_ratio)

    def _serialize_item(self, obj: object) -> bytes:
        # Simple cases
        if isinstance(obj, str):
            return obj.encode("utf-8")
        if isinstance(obj, bytes):
            return obj
        if isinstance(obj, Image.Image):
            return obj.tobytes()

        # Convertible to NumPy arrays
        if isinstance(obj, torch.Tensor):
            obj = obj.numpy()
        if isinstance(obj, (int, float)):
            obj = np.array(obj)
        if isinstance(obj, np.ndarray):
            return obj.tobytes()

        logger.warning(
            "No serialization method found for %s. "
            "Falling back to pickle.", type(obj))

        return pickle.dumps(obj)

    def _item_to_bytes(
        self,
        key: str,
        obj: object,
    ) -> Iterable[tuple[bytes, bytes]]:
        # Recursive cases
        if isinstance(obj, (list, tuple)):
            for i, elem in enumerate(obj):
                yield from self._item_to_bytes(f"{key}.{i}", elem)
        elif isinstance(obj, dict):
            for k, v in obj.items():
                yield from self._item_to_bytes(f"{key}.{k}", v)
        else:
            key_bytes = self._serialize_item(key)
            value_bytes = self._serialize_item(obj)
            yield key_bytes, value_bytes

    def _hash_kwargs(self, **kwargs: object) -> str:
        hasher = blake3()

        for k, v in kwargs.items():
            for k_bytes, v_bytes in self._item_to_bytes(k, v):
                hasher.update(k_bytes)
                hasher.update(v_bytes)

        return hasher.hexdigest()

    def get(
        self,
        model_id: str,
@@ -554,9 +501,9 @@ class ProcessingCache:
        """
        self._maybe_log_cache_stats()

        cache_key = self._hash_kwargs(model_id=model_id,
                                      **{modality: input_item},
                                      **input_kwargs)
        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: input_item},
                                                 **input_kwargs)
        return self._cache.get(cache_key)

    def put(
@@ -571,9 +518,9 @@ class ProcessingCache:
        Put a processed multi-modal item into the cache
        according to its dependencies (see :meth:`get`).
        """
        cache_key = self._hash_kwargs(model_id=model_id,
                                      **{modality: input_item},
                                      **input_kwargs)
        cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: input_item},
                                                 **input_kwargs)
        self._cache.put(cache_key, output_kwargs)


@@ -1049,6 +996,24 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
        """
        mm_items = self._to_mm_items(mm_data)

        # Create MM hashes (only used in V1)
        # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.

        if envs.VLLM_USE_V1:
            model_id = self.ctx.model_config.model
            mm_hashes = {
                modality: [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs)
                    for item in items
                ]
                for modality, items in mm_items.items()
            }
        else:
            mm_hashes = None

        prompt_ids, mm_kwargs = self._cached_apply_hf_processor(
            prompt_text,
            mm_items,
@@ -1122,6 +1087,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
            prompt=prompt_text,
            prompt_token_ids=prompt_ids,
            mm_kwargs=mm_kwargs,
            mm_hashes=mm_hashes,
            mm_placeholders=mm_placeholder_ranges,
        )

@@ -1174,7 +1140,9 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
                "tokens.")

        total_len = len(prompt_token_ids)
        if total_len > seq_len:

        # V0 does not support chunked prefill.
        if total_len > seq_len and not envs.VLLM_USE_V1:
            logger.warning(
                "The context length (%d) of the model is too short "
                "to hold the multi-modal embeddings in the worst case "

@@ -1,6 +1,6 @@
from functools import lru_cache
from pathlib import Path
from typing import Optional, TypeVar, Union
from typing import TYPE_CHECKING, Optional, TypeVar, Union
from urllib.parse import ParseResult, urlparse

import numpy as np
@@ -25,6 +25,10 @@ cached_get_tokenizer = lru_cache(get_tokenizer)

_M = TypeVar("_M")

if TYPE_CHECKING:
    from .hasher import MultiModalHashDict
    from .inputs import MultiModalPlaceholderDict


class MediaConnector:

@@ -437,3 +441,83 @@ def consecutive_placeholder_ranges(
        PlaceholderRange(offset=initial_offset + i * item_size,
                         length=item_size) for i in range(num_items)
    ]


def merge_and_sort_multimodal_metadata(
    mm_positions: "MultiModalPlaceholderDict",
    mm_hashes: Optional["MultiModalHashDict"],
) -> tuple[list[str], list[PlaceholderRange], Optional[list[str]]]:
    """Given a MultiModalPlaceholderDict, merge all PlaceholderRange
    objects from all available modalities into a single list of
    PlaceholderRange, sorted by their offset (starting index in the input
    sequence) in ascending order.

    Optionally, if a MultiModalHashDict is given, the same operation will be
    applied to it and the sorted list of hashes will be returned.

    Raises:
        ValueError: If the input prompt has interleaved placeholders from
            different modalities (e.g., "<image><audio><image> Describe the
            content.")

    Returns:
        list[str]: Sorted list of involved modalities.
        list[PlaceholderRange]: Sorted list of all PlaceholderRanges from
            mm_positions.
        Optional[list[str]]: Sorted list of all hashes from mm_hashes if
            given, None otherwise.
    """

    modalities = list(mm_positions.keys())

    assert len(modalities) > 0, "No modalities found in the mm_positions."

    # For single modality, placeholder ranges and hashes are already sorted
    # so we can return the list directly.
    if len(modalities) == 1:
        if mm_hashes is None:
            return modalities, list(mm_positions[modalities[0]]), None
        else:
            return modalities, list(mm_positions[modalities[0]]), list(
                mm_hashes[modalities[0]])

    placeholder_lists_with_modality = [(modality, mm_positions[modality])
                                       for modality in modalities]

    if mm_hashes is None:
        sorted_placeholder_lists = sorted(placeholder_lists_with_modality,
                                          key=lambda x: x[1][0]['offset'])
        sorted_hash_lists = None
    else:
        hashes_lists = [
            mm_hashes[modality] for modality in modalities
            if modality in mm_hashes
        ]
        sorted_pairs = sorted(zip(placeholder_lists_with_modality,
                                  hashes_lists),
                              key=lambda x: x[0][1][0]['offset'])
        sorted_placeholder_tuple, sorted_hash_tuple = zip(*sorted_pairs)
        sorted_placeholder_lists = list(sorted_placeholder_tuple)
        sorted_hash_lists = list(sorted_hash_tuple)

    sorted_modalities = [modality for modality, _ in sorted_placeholder_lists]

    # Flatten sorted list of lists to a single list and verify there is no
    # interleaving of placeholders from different modalities.
    merged_placeholders: list[PlaceholderRange] = []
    for modality, placeholder_list in sorted_placeholder_lists:
        if merged_placeholders and placeholder_list[0][
                'offset'] < merged_placeholders[-1]['offset']:
            raise ValueError(
                "Interleaved mixed-modality inference is currently not "
                "supported.")
        merged_placeholders.extend(placeholder_list)

    if sorted_hash_lists is not None:
        merged_hashes = []
        for hash_list in sorted_hash_lists:
            merged_hashes.extend(hash_list)
    else:
        merged_hashes = None

    return sorted_modalities, merged_placeholders, merged_hashes

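For reference, a small non-interleaved call of this helper (the offsets and hash strings are made up; it mirrors the unit tests added earlier in this commit):

from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata

modalities, ranges, hashes = merge_and_sort_multimodal_metadata(
    mm_positions={
        "image": [PlaceholderRange(offset=4, length=3)],
        "audio": [PlaceholderRange(offset=0, length=2)],
    },
    mm_hashes={
        "image": ["image_hash1"],
        "audio": ["audio_hash1"],
    },
)
assert modalities == ["audio", "image"]
assert [r["offset"] for r in ranges] == [0, 4]
assert hashes == ["audio_hash1", "image_hash1"]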
@@ -1,12 +1,14 @@
import enum
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union

import msgspec

from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
from vllm.sampling_params import SamplingParams
if TYPE_CHECKING:
    from vllm.lora.request import LoRARequest
    from vllm.multimodal import MultiModalKwargs
    from vllm.multimodal.inputs import PlaceholderRange
    from vllm.sampling_params import SamplingParams


@dataclass
@@ -21,13 +23,13 @@ class EngineCoreRequest:
    # always be tokenized?
    prompt: Optional[str]
    prompt_token_ids: List[int]
    mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
    mm_inputs: Optional[List[Optional["MultiModalKwargs"]]]
    mm_hashes: Optional[List[str]]
    mm_placeholders: Optional[MultiModalPlaceholderDict]
    sampling_params: SamplingParams
    mm_placeholders: Optional[List["PlaceholderRange"]]
    sampling_params: "SamplingParams"
    eos_token_id: Optional[int]
    arrival_time: float
    lora_request: Optional[LoRARequest]
    lora_request: Optional["LoRARequest"]


class EngineCoreOutput(

@@ -1,10 +1,6 @@
from typing import Any, Dict, List, Optional

import PIL
from blake3 import blake3

from vllm.config import ModelConfig
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                             MultiModalKwargs, MultiModalRegistry)
@@ -144,66 +140,3 @@ class MMInputMapperServer:
            full_mm_inputs.append(mm_input)

        return full_mm_inputs


class MMHasher:

    def __init__(self):
        pass

    def hash_dummy_mm_data(
            self,
            mm_data: Optional[MultiModalDataDict]) -> Optional[List[str]]:
        """Hash user-defined dummy multimodal data used for profiling."""

        if mm_data is None:
            return None

        image_inputs = mm_data['image']

        # This is a temporary workaround for models (e.g, Molmo) that
        # process multimodal data in the input processor (therefore
        # image_inputs is MultiModalKwargs instead of raw input format).
        # `raw_mm_data` with the original input format is expected
        # in this case.
        if isinstance(image_inputs, dict):
            assert "raw_mm_data" in image_inputs and isinstance(
                image_inputs["raw_mm_data"], PIL.Image.Image)
            image_inputs = image_inputs.pop("raw_mm_data")

        return self.hash_images(image_inputs)

    def hash_prompt_mm_data(self, prompt: PromptType) -> Optional[List[str]]:
        """Hash multimodal data in the user input prompt if they exist."""

        if "multi_modal_data" not in prompt:
            return None

        mm_data = prompt["multi_modal_data"]
        if not mm_data:
            # mm_data can be None or an empty dict.
            return None

        image_inputs = mm_data["image"]

        return self.hash_images(image_inputs)

    def hash_images(self, image_inputs) -> Optional[List[str]]:
        """Hash PIL image objects to strings."""
        if not isinstance(image_inputs, list):
            image_inputs = [image_inputs]
        assert len(image_inputs) > 0

        ret = []
        for image in image_inputs:
            assert isinstance(image, PIL.Image.Image)

            # Convert image to bytes
            bytes = image.tobytes()

            # Hash image bytes
            hasher = blake3()
            hasher.update(bytes)
            ret.append(hasher.hexdigest())

        return ret

@@ -7,14 +7,15 @@ from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
                             MultiModalRegistry)
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalHasher,
                             MultiModalKwargs, MultiModalRegistry)
from vllm.multimodal.utils import merge_and_sort_multimodal_metadata
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient


class Processor:
@@ -47,7 +48,6 @@ class Processor:
        # Multi-modal hasher (for images)
        self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
            cache_config.enable_prefix_caching
        self.mm_hasher = MMHasher()

    def process_inputs(
        self,
@@ -73,11 +73,6 @@ class Processor:
        assert priority == 0, "vLLM V1 does not support priority at the moment."
        assert trace_headers is None, "vLLM V1 does not support tracing yet."

        # Compute MM hashes (if enabled)
        mm_hashes = None
        if self.use_hash:
            mm_hashes = self.mm_hasher.hash_prompt_mm_data(prompt)

        # Process inputs.
        preprocessed_inputs = self.input_preprocessor.preprocess(
            prompt,
@@ -108,8 +103,20 @@ class Processor:
        sampling_params.update_from_generation_config(
            self.generation_config_fields, eos_token_id)

        # Multimodal related.
        # Compute MM hashes (if enabled)
        mm_hashes = None
        if self.use_hash:
            # Use mm_hashes from processed inputs if the model has merged
            # input processor.
            if decoder_inputs.multi_modal_hashes:
                mm_hashes = decoder_inputs.multi_modal_hashes
            # Fallback to using MultiModalHasher directly.
            else:
                mm_hashes = MultiModalHasher.hash_prompt_mm_data(prompt)

        # For merged preprocessor, mm_data is already mm_inputs
        precomputed_mm_inputs = None
        precomputed_mm_inputs: Optional[list[MultiModalKwargs]] = None
        decoder_mm_data = decoder_inputs.multi_modal_data
        if isinstance(decoder_mm_data, MultiModalKwargs):
            # The output of merged multi-modal processor (`decoder_mm_data`)
@@ -122,27 +129,67 @@ class Processor:
                for item in decoder_mm_data.get_items(modality)
            ]

        # Apply MM mapper
        mm_inputs = None
        if len(decoder_mm_data) > 0:
            mm_inputs = self.mm_input_mapper_client.process_inputs(
                decoder_mm_data,
        mm_positions = decoder_inputs.multi_modal_placeholders

        # Last-mile processing of multimodal metadata and inputs.
        if mm_positions:

            # Merge and flatten multimodal placeholders, hashes and inputs
            # from dictionaries to lists, and sort them by each item's position
            # in the input sequence.
            # NOTE: interleaved modalities are not supported.
            (
                sorted_modalities,
                sorted_mm_positions,
                sorted_mm_hashes,
            ) = merge_and_sort_multimodal_metadata(
                mm_positions,
                mm_hashes,
                decoder_inputs.mm_processor_kwargs,
                precomputed_mm_inputs,
            )

            # NOTE: Sort multimodal inputs/kwargs ONLY IF there are multiple
            # modalities involved AND the model supports merged input processor.
            if len(sorted_modalities) > 1 and precomputed_mm_inputs:

                modality_order_dict = {
                    modality: order
                    for order, modality in enumerate(sorted_modalities)
                }

                # Sanity check to make sure each multimodal input has only one
                # modality key.
                for mm_input in precomputed_mm_inputs:
                    assert len(mm_input.modalities) == 1

                # Sort MultiModalKwargs to match sorted_mm_positions
                precomputed_mm_inputs = sorted(
                    precomputed_mm_inputs,
                    key=lambda mm_input: modality_order_dict[list(
                        mm_input.modalities)[0]])

            # Apply mm input cache update (and input mapper if necessary).
            sorted_mm_inputs = self.mm_input_mapper_client.process_inputs(
                mm_data=decoder_mm_data,
                mm_hashes=sorted_mm_hashes,
                mm_processor_kwargs=decoder_inputs.mm_processor_kwargs,
                precomputed_mm_inputs=precomputed_mm_inputs,
            )
        else:
            sorted_mm_inputs = None
            sorted_mm_hashes = None
            sorted_mm_positions = None

        return EngineCoreRequest(
            request_id,
            decoder_inputs.prompt,
            decoder_inputs.prompt_token_ids,
            mm_inputs,
            mm_hashes,
            decoder_inputs.multi_modal_placeholders,
            sampling_params,
            eos_token_id,
            arrival_time,
            lora_request,
            request_id=request_id,
            prompt=decoder_inputs.prompt,
            prompt_token_ids=decoder_inputs.prompt_token_ids,
            mm_inputs=sorted_mm_inputs,
            mm_hashes=sorted_mm_hashes,
            mm_placeholders=sorted_mm_positions,
            sampling_params=sampling_params,
            eos_token_id=eos_token_id,
            arrival_time=arrival_time,
            lora_request=lora_request,
        )

    def _validate_model_inputs(self, inputs: ProcessorInputs):

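To illustrate the modality-order sort in the branch above, here is a tiny standalone sketch; plain dicts stand in for MultiModalKwargs items whose `.modalities` set contains exactly one key, and only the ordering logic is shown:

# Hypothetical single-modality inputs, one per multimodal data item.
precomputed_mm_inputs = [{"modality": "image"}, {"modality": "audio"},
                         {"modality": "video"}]
sorted_modalities = ["audio", "video", "image"]

modality_order_dict = {
    modality: order
    for order, modality in enumerate(sorted_modalities)
}
precomputed_mm_inputs = sorted(
    precomputed_mm_inputs,
    key=lambda mm_input: modality_order_dict[mm_input["modality"]])

assert [mm["modality"] for mm in precomputed_mm_inputs] == ["audio", "video", "image"]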
@@ -1,15 +1,15 @@
import enum
from typing import TYPE_CHECKING, List, Optional, Union

from vllm.inputs import DecoderOnlyInputs, SingletonInputsAdapter, token_inputs
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.utils import ConstantList

if TYPE_CHECKING:
    from vllm.multimodal import MultiModalKwargs
    from vllm.multimodal.inputs import PlaceholderRange
    from vllm.v1.core.kv_cache_utils import BlockHashType


@@ -18,14 +18,17 @@ class Request:
    def __init__(
        self,
        request_id: str,
        inputs: DecoderOnlyInputs,
        prompt: Optional[str],
        prompt_token_ids: List[int],
        multi_modal_inputs: Optional[List["MultiModalKwargs"]],
        multi_modal_hashes: Optional[List[str]],
        multi_modal_placeholders: Optional[List["PlaceholderRange"]],
        sampling_params: SamplingParams,
        eos_token_id: Optional[int],
        arrival_time: float,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.request_id = request_id
        self.inputs = SingletonInputsAdapter(inputs)
        self.sampling_params = sampling_params
        # Because of LoRA, the eos token id can be different for each request.
        self.eos_token_id = eos_token_id
@@ -41,26 +44,21 @@ class Request:
        assert sampling_params.max_tokens is not None
        self.max_tokens = sampling_params.max_tokens

        self.prompt = self.inputs.prompt
        self.prompt_token_ids = self.inputs.prompt_token_ids
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.num_prompt_tokens = len(self.prompt_token_ids)
        self._output_token_ids: List[int] = []
        self._all_token_ids: List[int] = self.prompt_token_ids.copy()
        self.num_computed_tokens = 0

        # Multi-modal input metadata.
        mm_positions = self.inputs.multi_modal_placeholders
        if mm_positions:
            # FIXME(woosuk): Support other modalities.
            self.mm_positions = mm_positions.get("image", [])
        else:
            self.mm_positions = []
        # Output of the mm input mapper (e.g., image tensors).
        self.mm_inputs: List[MultiModalKwargs] = []
        if self.inputs.multi_modal_inputs:
            self.mm_inputs = self.inputs.multi_modal_inputs
        # Multi-modal related
        self.mm_positions = multi_modal_placeholders or []
        self.mm_inputs = multi_modal_inputs or []
        self.mm_hashes: List[str] = multi_modal_hashes or []

        self.mm_hashes: List[str] = self.inputs.multi_modal_hashes
        # Sanity check
        assert len(self.mm_inputs) == len(self.mm_positions)
        assert len(self.mm_inputs) == len(self.mm_hashes)

        # Cache the computed kv block hashes of the request to avoid
        # recomputing.
@@ -70,15 +68,11 @@ class Request:
    def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
        return cls(
            request_id=request.request_id,
            inputs=token_inputs(
                prompt_token_ids=request.prompt_token_ids,
                prompt=request.prompt,
                multi_modal_data=None,
                multi_modal_inputs=request.mm_inputs,
                multi_modal_hashes=request.mm_hashes,
                multi_modal_placeholders=request.mm_placeholders,
                mm_processor_kwargs=None,
            ),
            prompt=request.prompt,
            prompt_token_ids=request.prompt_token_ids,
            multi_modal_inputs=request.mm_inputs,
            multi_modal_hashes=request.mm_hashes,
            multi_modal_placeholders=request.mm_placeholders,
            sampling_params=request.sampling_params,
            eos_token_id=request.eos_token_id,
            arrival_time=request.arrival_time,

@@ -19,7 +19,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                   FlashAttentionMetadata)
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -82,12 +82,10 @@ class GPUModelRunner:
        self.input_registry = INPUT_REGISTRY
        self.mm_registry = MULTIMODAL_REGISTRY

        # NOTE: mm_input_mapper_client and mm_hasher are only used for memory
        # profiling.
        self.mm_input_mapper_client = MMInputMapperClient(self.model_config)
        self.mm_hasher = MMHasher()
        self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
            cache_config.enable_prefix_caching
        # NOTE: Initialized input mapper is only used for processing dummy
        # multimodal data into multimodal kwargs for GPU memory profiling.
        self.mm_input_mapper_profiling = MMInputMapperClient(self.model_config)
        self.mm_input_mapper_profiling.use_cache = False

        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  # noqa: E501
        self.encoder_cache_size = self.scheduler_config.encoder_cache_size
@@ -722,8 +720,6 @@ class GPUModelRunner:
        ]

        # Profile with multimodal encoder & encoder cache.
        # TODO (ywang96): generalize this beyond image modality since
        # mm_input_mapper only supports image inputs.
        if self.is_multimodal_model:

            # Create dummy batch of multimodal inputs.
@@ -735,15 +731,30 @@ class GPUModelRunner:
            dummy_mm_data = dummy_request_data.multi_modal_data

            # NOTE: Currently model is profiled with a single non-text
            # modality even when it supports multiple.
            max_tokens_per_mm_item = max(
                self.mm_registry.get_max_tokens_per_item_by_modality(
                    self.model_config).values())
            # modality with the max possible input tokens even when
            # it supports multiple.
            max_tokens_by_modality_dict = self.mm_registry.get_max_tokens_per_item_by_modality(  # noqa: E501
                self.model_config)

            max_num_mm_items_encoder_budget = min(
                self.max_num_encoder_input_tokens,
                self.encoder_cache_size) // max_tokens_per_mm_item
            dummy_data_modality, max_tokens_per_mm_item = max(
                max_tokens_by_modality_dict.items(), key=lambda item: item[1])

            # Check how many items of this modality can be supported by
            # the encoder cache budget.
            encoder_cache_budget = min(self.max_num_encoder_input_tokens,
                                       self.encoder_cache_size)
            max_num_mm_items_encoder_budget = encoder_cache_budget // \
                max_tokens_per_mm_item

            # TODO: Allow users to set encoder_cache_budget in case this
            # happens.
            assert max_num_mm_items_encoder_budget > 0, (
                f"Encoder cache budget={encoder_cache_budget} is too small to "
                f"support the maximum possible size of multimodal embeddings"
                f"={max_tokens_per_mm_item}.")

            # Check how many items of this modality can be supported by
            # the decoder budget.
            max_mm_items_per_req = max(
                self.mm_registry.get_mm_limits_per_prompt(
                    self.model_config).values())
@@ -763,33 +774,24 @@ class GPUModelRunner:
            # they are scheduled to be processed separately.

            # Case when models have a merged processor, their dummy data is
            # already batched `MultiModalKwargs`, therefore we need to "unbatch"
            # and take the first item in each batched tensor.
            # TODO (ywang96): This is somewhat hacky. Refactor this to be
            # consistent with the other case.
            # already batched `MultiModalKwargs`, therefore we take the first
            # `MultiModalKwargsItem` from the desired modality to profile on.
            if isinstance(dummy_mm_data, MultiModalKwargs):
                dummy_mm_kwargs = {
                    k: v[0].unsqueeze(0)
                    for k, v in dummy_mm_data.items()
                }
                dummy_mm_item = dummy_mm_data.get_item(
                    modality=dummy_data_modality, item_index=0)
                dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])

            # Case when models have dummy data explicitly defined as
            # `MultiModalDataDict`, so they need to be processed through input
            # mapper.
            # TODO (ywang96): deprecate this path once merged processor is
            # supported on all models.
            else:
                # Compute MM hashes (if enabled)
                mm_hashes = None
                if self.use_hash:
                    mm_hashes = self.mm_hasher.hash_dummy_mm_data(
                        dummy_mm_data)

                mm_kwargs_list = self.mm_input_mapper_client.process_inputs(
                mm_kwargs_list = self.mm_input_mapper_profiling.process_inputs(
                    mm_data=dummy_mm_data,
                    mm_hashes=mm_hashes,
                    mm_hashes=None,
                    mm_processor_kwargs=None,
                    precomputed_mm_inputs=None)

                # Take the first `MultiModalKwargs`
                dummy_mm_kwargs = mm_kwargs_list[0]

            batched_dummy_mm_inputs = MultiModalKwargs.batch(
