[Multimodal][Core] Optimize multimodal preprocessing cache by hashing image bytes instead of pixel values (#29621)

Signed-off-by: Rahul Steiger <rasteiger@ethz.ch>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
ImaGoodFella 2025-12-02 14:49:02 +01:00 committed by GitHub
parent 68ffbca7e4
commit 60c3d413af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 95 additions and 19 deletions

View File

@ -59,6 +59,7 @@ from vllm.distributed import (
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.multimodal.base import MediaWithBytes
from vllm.multimodal.utils import fetch_image from vllm.multimodal.utils import fetch_image
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams from vllm.sampling_params import BeamSearchParams
@ -1389,7 +1390,11 @@ class LocalAssetServer:
return f"{self.base_url}/{name}" return f"{self.base_url}/{name}"
def get_image_asset(self, name: str) -> Image.Image: def get_image_asset(self, name: str) -> Image.Image:
return fetch_image(self.url_for(name)) image = fetch_image(self.url_for(name))
# Unwrap MediaWithBytes if present
if isinstance(image, MediaWithBytes):
image = image.media
return image
@pytest.fixture(scope="session") @pytest.fixture(scope="session")

View File

@ -8,6 +8,7 @@ import pytest
import pytest_asyncio import pytest_asyncio
from transformers import AutoProcessor from transformers import AutoProcessor
from vllm.multimodal.base import MediaWithBytes
from vllm.multimodal.utils import encode_image_base64, fetch_image from vllm.multimodal.utils import encode_image_base64, fetch_image
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
@ -111,7 +112,11 @@ def get_hf_prompt_tokens(model_name, content, image_url):
"content": f"{placeholder}{content}", "content": f"{placeholder}{content}",
} }
] ]
images = [fetch_image(image_url)] image = fetch_image(image_url)
# Unwrap MediaWithBytes if present
if isinstance(image, MediaWithBytes):
image = image.media
images = [image]
prompt = processor.tokenizer.apply_chat_template( prompt = processor.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True

View File

@ -9,6 +9,7 @@ from transformers import AutoProcessor
from tests.utils import VLLM_PATH, RemoteOpenAIServer from tests.utils import VLLM_PATH, RemoteOpenAIServer
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
from vllm.multimodal.base import MediaWithBytes
from vllm.multimodal.utils import encode_image_base64, fetch_image from vllm.multimodal.utils import encode_image_base64, fetch_image
MODEL_NAME = "TIGER-Lab/VLM2Vec-Full" MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
@ -62,7 +63,11 @@ def get_hf_prompt_tokens(model_name, content, image_url):
placeholder = "<|image_1|> " placeholder = "<|image_1|> "
prompt = f"{placeholder}{content}" prompt = f"{placeholder}{content}"
images = [fetch_image(image_url)] image = fetch_image(image_url)
# Unwrap MediaWithBytes if present
if isinstance(image, MediaWithBytes):
image = image.media
images = [image]
inputs = processor(prompt, images, return_tensors="pt") inputs = processor(prompt, images, return_tensors="pt")
return inputs.input_ids.shape[1] return inputs.input_ids.shape[1]

View File

@ -2,12 +2,40 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Generic, TypeVar from typing import Generic, TypeVar
import numpy as np
_T = TypeVar("_T") _T = TypeVar("_T")
@dataclass
class MediaWithBytes(Generic[_T]):
"""
Wrapper that couples a media object with its original encoded bytes.
This ensures the raw bytes and media object remain synchronized,
preventing cache corruption from in-place modifications.
The wrapper delegates attribute access to the underlying media object,
making it behave transparently like the wrapped type (e.g., PIL.Image).
"""
media: _T
original_bytes: bytes
def __array__(self, *args, **kwargs) -> np.ndarray:
"""Allow np.array(obj) to return np.array(obj.media)."""
return np.array(self.media, *args, **kwargs)
def __getattr__(self, name: str):
"""Delegate attribute access to the underlying media object."""
# This is only called when the attribute is not found on self
return getattr(self.media, name)
class MediaIO(ABC, Generic[_T]): class MediaIO(ABC, Generic[_T]):
@abstractmethod @abstractmethod
def load_bytes(self, data: bytes) -> _T: def load_bytes(self, data: bytes) -> _T:

View File

@ -12,6 +12,8 @@ from PIL import Image
from vllm.logger import init_logger from vllm.logger import init_logger
from .base import MediaWithBytes
logger = init_logger(__name__) logger = init_logger(__name__)
@ -31,14 +33,26 @@ class MultiModalHasher:
if Image.ExifTags.Base.ImageID in exif and isinstance( if Image.ExifTags.Base.ImageID in exif and isinstance(
exif[Image.ExifTags.Base.ImageID], uuid.UUID exif[Image.ExifTags.Base.ImageID], uuid.UUID
): ):
# If the image has exif ImageID tag, use that
return (exif[Image.ExifTags.Base.ImageID].bytes,) return (exif[Image.ExifTags.Base.ImageID].bytes,)
data = {"mode": obj.mode, "data": np.asarray(obj)} data = {"mode": obj.mode, "data": np.asarray(obj)}
if obj.palette is not None: palette = obj.palette
data["palette"] = obj.palette.palette if palette is not None:
if obj.palette.rawmode is not None: data["palette"] = palette.palette
data["palette_rawmode"] = obj.palette.rawmode if palette.rawmode is not None:
data["palette_rawmode"] = palette.rawmode
return cls.iter_item_to_bytes("image", data) return cls.iter_item_to_bytes("image", data)
if isinstance(obj, MediaWithBytes) and isinstance(obj.media, Image.Image):
exif = obj.media.getexif()
if Image.ExifTags.Base.ImageID in exif and isinstance(
exif[Image.ExifTags.Base.ImageID], uuid.UUID
):
return (exif[Image.ExifTags.Base.ImageID].bytes,)
return cls.iter_item_to_bytes("image", obj.original_bytes)
if isinstance(obj, torch.Tensor): if isinstance(obj, torch.Tensor):
tensor_obj: torch.Tensor = obj.cpu() tensor_obj: torch.Tensor = obj.cpu()
tensor_dtype = tensor_obj.dtype tensor_dtype = tensor_obj.dtype

View File

@ -8,7 +8,7 @@ import pybase64
import torch import torch
from PIL import Image from PIL import Image
from .base import MediaIO from .base import MediaIO, MediaWithBytes
def rescale_image_size( def rescale_image_size(
@ -74,8 +74,12 @@ class ImageMediaIO(MediaIO[Image.Image]):
) )
self.rgba_background_color = rgba_bg self.rgba_background_color = rgba_bg
def _convert_image_mode(self, image: Image.Image) -> Image.Image: def _convert_image_mode(
self, image: Image.Image | MediaWithBytes[Image.Image]
) -> Image.Image:
"""Convert image mode with custom background color.""" """Convert image mode with custom background color."""
if isinstance(image, MediaWithBytes):
image = image.media
if image.mode == self.image_mode: if image.mode == self.image_mode:
return image return image
elif image.mode == "RGBA" and self.image_mode == "RGB": elif image.mode == "RGBA" and self.image_mode == "RGB":
@ -83,18 +87,18 @@ class ImageMediaIO(MediaIO[Image.Image]):
else: else:
return convert_image_mode(image, self.image_mode) return convert_image_mode(image, self.image_mode)
def load_bytes(self, data: bytes) -> Image.Image: def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]:
image = Image.open(BytesIO(data)) image = Image.open(BytesIO(data))
image.load() return MediaWithBytes(self._convert_image_mode(image), data)
return self._convert_image_mode(image)
def load_base64(self, media_type: str, data: str) -> Image.Image: def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]:
return self.load_bytes(pybase64.b64decode(data, validate=True)) return self.load_bytes(pybase64.b64decode(data, validate=True))
def load_file(self, filepath: Path) -> Image.Image: def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]:
image = Image.open(filepath) with open(filepath, "rb") as f:
image.load() data = f.read()
return self._convert_image_mode(image) image = Image.open(BytesIO(data))
return MediaWithBytes(self._convert_image_mode(image), data)
def encode_base64( def encode_base64(
self, self,

View File

@ -23,6 +23,7 @@ from vllm.utils.collection_utils import is_list_of
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from .audio import AudioResampler from .audio import AudioResampler
from .base import MediaWithBytes
from .inputs import ( from .inputs import (
AudioItem, AudioItem,
HfAudioItem, HfAudioItem,
@ -84,6 +85,12 @@ class ModalityDataItems(ABC, Generic[_T, _I]):
"""Get all data items.""" """Get all data items."""
return [self.get(idx) for idx in range(self.get_count())] return [self.get(idx) for idx in range(self.get_count())]
def get_item_for_hash(self, index: int) -> object:
return self.get(index)
def get_all_items_for_hash(self) -> list[object]:
return [self.get_item_for_hash(idx) for idx in range(self.get_count())]
@abstractmethod @abstractmethod
def get_processor_data(self) -> Mapping[str, object]: def get_processor_data(self) -> Mapping[str, object]:
"""Get the data to pass to the HF processor.""" """Get the data to pass to the HF processor."""
@ -98,10 +105,18 @@ class ModalityDataItems(ABC, Generic[_T, _I]):
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]): class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
"""Base class for data items that are arranged in a list.""" """Base class for data items that are arranged in a list."""
def _unwrap(self, item: _T | MediaWithBytes[_T]) -> _T:
"""Extract media from wrapper if present."""
return item.media if isinstance(item, MediaWithBytes) else item
def get_count(self) -> int: def get_count(self) -> int:
return len(self.data) return len(self.data)
def get(self, index: int) -> _T: def get(self, index: int) -> _T:
return self._unwrap(self.data[index])
def get_item_for_hash(self, index: int) -> _T | MediaWithBytes[_T]:
# Return raw item for hashing (preserves original_bytes if present)
return self.data[index] return self.data[index]
def get_processor_data(self) -> Mapping[str, object]: def get_processor_data(self) -> Mapping[str, object]:

View File

@ -1684,7 +1684,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# For None entries, compute a hash; otherwise, use provided ID. # For None entries, compute a hash; otherwise, use provided ID.
computed: list[str] = [] computed: list[str] = []
for i, item in enumerate(items): for i, item in enumerate(items.get_all_items_for_hash()):
item_uuid = mm_uuids_per_modality[i] item_uuid = mm_uuids_per_modality[i]
# NOTE: Even if a item_uuid is provided, we still compute a # NOTE: Even if a item_uuid is provided, we still compute a