[Multimodal][Core] Optimize multimodal preprocessing cache by hashing image bytes instead of pixel values (#29621)

Signed-off-by: Rahul Steiger <rasteiger@ethz.ch>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
ImaGoodFella 2025-12-02 14:49:02 +01:00 committed by GitHub
parent 68ffbca7e4
commit 60c3d413af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 95 additions and 19 deletions

View File

@ -59,6 +59,7 @@ from vllm.distributed import (
)
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.multimodal.base import MediaWithBytes
from vllm.multimodal.utils import fetch_image
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
@ -1389,7 +1390,11 @@ class LocalAssetServer:
return f"{self.base_url}/{name}"
def get_image_asset(self, name: str) -> Image.Image:
return fetch_image(self.url_for(name))
image = fetch_image(self.url_for(name))
# Unwrap MediaWithBytes if present
if isinstance(image, MediaWithBytes):
image = image.media
return image
@pytest.fixture(scope="session")

View File

@ -8,6 +8,7 @@ import pytest
import pytest_asyncio
from transformers import AutoProcessor
from vllm.multimodal.base import MediaWithBytes
from vllm.multimodal.utils import encode_image_base64, fetch_image
from ...utils import RemoteOpenAIServer
@ -111,7 +112,11 @@ def get_hf_prompt_tokens(model_name, content, image_url):
"content": f"{placeholder}{content}",
}
]
images = [fetch_image(image_url)]
image = fetch_image(image_url)
# Unwrap MediaWithBytes if present
if isinstance(image, MediaWithBytes):
image = image.media
images = [image]
prompt = processor.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True

View File

@ -9,6 +9,7 @@ from transformers import AutoProcessor
from tests.utils import VLLM_PATH, RemoteOpenAIServer
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
from vllm.multimodal.base import MediaWithBytes
from vllm.multimodal.utils import encode_image_base64, fetch_image
MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
@ -62,7 +63,11 @@ def get_hf_prompt_tokens(model_name, content, image_url):
placeholder = "<|image_1|> "
prompt = f"{placeholder}{content}"
images = [fetch_image(image_url)]
image = fetch_image(image_url)
# Unwrap MediaWithBytes if present
if isinstance(image, MediaWithBytes):
image = image.media
images = [image]
inputs = processor(prompt, images, return_tensors="pt")
return inputs.input_ids.shape[1]

View File

@ -2,12 +2,40 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Generic, TypeVar
import numpy as np
_T = TypeVar("_T")
@dataclass
class MediaWithBytes(Generic[_T]):
"""
Wrapper that couples a media object with its original encoded bytes.
This ensures the raw bytes and media object remain synchronized,
preventing cache corruption from in-place modifications.
The wrapper delegates attribute access to the underlying media object,
making it behave transparently like the wrapped type (e.g., PIL.Image).
"""
media: _T
original_bytes: bytes
def __array__(self, *args, **kwargs) -> np.ndarray:
"""Allow np.array(obj) to return np.array(obj.media)."""
return np.array(self.media, *args, **kwargs)
def __getattr__(self, name: str):
"""Delegate attribute access to the underlying media object."""
# This is only called when the attribute is not found on self
return getattr(self.media, name)
class MediaIO(ABC, Generic[_T]):
@abstractmethod
def load_bytes(self, data: bytes) -> _T:

View File

@ -12,6 +12,8 @@ from PIL import Image
from vllm.logger import init_logger
from .base import MediaWithBytes
logger = init_logger(__name__)
@ -31,14 +33,26 @@ class MultiModalHasher:
if Image.ExifTags.Base.ImageID in exif and isinstance(
exif[Image.ExifTags.Base.ImageID], uuid.UUID
):
# If the image has exif ImageID tag, use that
return (exif[Image.ExifTags.Base.ImageID].bytes,)
data = {"mode": obj.mode, "data": np.asarray(obj)}
if obj.palette is not None:
data["palette"] = obj.palette.palette
if obj.palette.rawmode is not None:
data["palette_rawmode"] = obj.palette.rawmode
palette = obj.palette
if palette is not None:
data["palette"] = palette.palette
if palette.rawmode is not None:
data["palette_rawmode"] = palette.rawmode
return cls.iter_item_to_bytes("image", data)
if isinstance(obj, MediaWithBytes) and isinstance(obj.media, Image.Image):
exif = obj.media.getexif()
if Image.ExifTags.Base.ImageID in exif and isinstance(
exif[Image.ExifTags.Base.ImageID], uuid.UUID
):
return (exif[Image.ExifTags.Base.ImageID].bytes,)
return cls.iter_item_to_bytes("image", obj.original_bytes)
if isinstance(obj, torch.Tensor):
tensor_obj: torch.Tensor = obj.cpu()
tensor_dtype = tensor_obj.dtype

View File

@ -8,7 +8,7 @@ import pybase64
import torch
from PIL import Image
from .base import MediaIO
from .base import MediaIO, MediaWithBytes
def rescale_image_size(
@ -74,8 +74,12 @@ class ImageMediaIO(MediaIO[Image.Image]):
)
self.rgba_background_color = rgba_bg
def _convert_image_mode(self, image: Image.Image) -> Image.Image:
def _convert_image_mode(
self, image: Image.Image | MediaWithBytes[Image.Image]
) -> Image.Image:
"""Convert image mode with custom background color."""
if isinstance(image, MediaWithBytes):
image = image.media
if image.mode == self.image_mode:
return image
elif image.mode == "RGBA" and self.image_mode == "RGB":
@ -83,18 +87,18 @@ class ImageMediaIO(MediaIO[Image.Image]):
else:
return convert_image_mode(image, self.image_mode)
def load_bytes(self, data: bytes) -> Image.Image:
def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]:
image = Image.open(BytesIO(data))
image.load()
return self._convert_image_mode(image)
return MediaWithBytes(self._convert_image_mode(image), data)
def load_base64(self, media_type: str, data: str) -> Image.Image:
def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]:
return self.load_bytes(pybase64.b64decode(data, validate=True))
def load_file(self, filepath: Path) -> Image.Image:
image = Image.open(filepath)
image.load()
return self._convert_image_mode(image)
def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]:
with open(filepath, "rb") as f:
data = f.read()
image = Image.open(BytesIO(data))
return MediaWithBytes(self._convert_image_mode(image), data)
def encode_base64(
self,

View File

@ -23,6 +23,7 @@ from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader
from .audio import AudioResampler
from .base import MediaWithBytes
from .inputs import (
AudioItem,
HfAudioItem,
@ -84,6 +85,12 @@ class ModalityDataItems(ABC, Generic[_T, _I]):
"""Get all data items."""
return [self.get(idx) for idx in range(self.get_count())]
def get_item_for_hash(self, index: int) -> object:
return self.get(index)
def get_all_items_for_hash(self) -> list[object]:
return [self.get_item_for_hash(idx) for idx in range(self.get_count())]
@abstractmethod
def get_processor_data(self) -> Mapping[str, object]:
"""Get the data to pass to the HF processor."""
@ -98,10 +105,18 @@ class ModalityDataItems(ABC, Generic[_T, _I]):
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
"""Base class for data items that are arranged in a list."""
def _unwrap(self, item: _T | MediaWithBytes[_T]) -> _T:
"""Extract media from wrapper if present."""
return item.media if isinstance(item, MediaWithBytes) else item
def get_count(self) -> int:
return len(self.data)
def get(self, index: int) -> _T:
return self._unwrap(self.data[index])
def get_item_for_hash(self, index: int) -> _T | MediaWithBytes[_T]:
# Return raw item for hashing (preserves original_bytes if present)
return self.data[index]
def get_processor_data(self) -> Mapping[str, object]:

View File

@ -1684,7 +1684,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# For None entries, compute a hash; otherwise, use provided ID.
computed: list[str] = []
for i, item in enumerate(items):
for i, item in enumerate(items.get_all_items_for_hash()):
item_uuid = mm_uuids_per_modality[i]
# NOTE: Even if a item_uuid is provided, we still compute a