From 60c3d413afccab6a1f9a18cf3cd1fe11019c1040 Mon Sep 17 00:00:00 2001
From: ImaGoodFella <31959740+ImaGoodFella@users.noreply.github.com>
Date: Tue, 2 Dec 2025 14:49:02 +0100
Subject: [PATCH] [Multimodal][Core] Optimize multimodal preprocessing cache
 by hashing image bytes instead of pixel values (#29621)

Signed-off-by: Rahul Steiger
Co-authored-by: Cyrus Leung
---
 tests/conftest.py                       |  7 ++++-
 tests/entrypoints/openai/test_vision.py |  7 ++++-
 .../pooling/embed/test_online_vision.py |  7 ++++-
 vllm/multimodal/base.py                 | 28 +++++++++++++++++++
 vllm/multimodal/hasher.py               | 24 ++++++++++++----
 vllm/multimodal/image.py                | 24 +++++++++-------
 vllm/multimodal/parse.py                | 15 ++++++++++
 vllm/multimodal/processing.py           |  2 +-
 8 files changed, 95 insertions(+), 19 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 53bbaddd0bb7f..b20c9efef542a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -59,6 +59,7 @@ from vllm.distributed import (
 )
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
+from vllm.multimodal.base import MediaWithBytes
 from vllm.multimodal.utils import fetch_image
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
@@ -1389,7 +1390,11 @@ class LocalAssetServer:
         return f"{self.base_url}/{name}"
 
     def get_image_asset(self, name: str) -> Image.Image:
-        return fetch_image(self.url_for(name))
+        image = fetch_image(self.url_for(name))
+        # Unwrap MediaWithBytes if present
+        if isinstance(image, MediaWithBytes):
+            image = image.media
+        return image
 
 
 @pytest.fixture(scope="session")
diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py
index d83c6726e72da..ae8860ee877b4 100644
--- a/tests/entrypoints/openai/test_vision.py
+++ b/tests/entrypoints/openai/test_vision.py
@@ -8,6 +8,7 @@ import pytest
 import pytest_asyncio
 from transformers import AutoProcessor
 
+from vllm.multimodal.base import MediaWithBytes
 from vllm.multimodal.utils import encode_image_base64, fetch_image
 
 from ...utils import RemoteOpenAIServer
@@ -111,7 +112,11 @@ def get_hf_prompt_tokens(model_name, content, image_url):
             "content": f"{placeholder}{content}",
         }
     ]
-    images = [fetch_image(image_url)]
+    image = fetch_image(image_url)
+    # Unwrap MediaWithBytes if present
+    if isinstance(image, MediaWithBytes):
+        image = image.media
+    images = [image]
 
     prompt = processor.tokenizer.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
diff --git a/tests/entrypoints/pooling/embed/test_online_vision.py b/tests/entrypoints/pooling/embed/test_online_vision.py
index 83e7048b9def6..eebbcdd2e4396 100644
--- a/tests/entrypoints/pooling/embed/test_online_vision.py
+++ b/tests/entrypoints/pooling/embed/test_online_vision.py
@@ -9,6 +9,7 @@ from transformers import AutoProcessor
 
 from tests.utils import VLLM_PATH, RemoteOpenAIServer
 from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
+from vllm.multimodal.base import MediaWithBytes
 from vllm.multimodal.utils import encode_image_base64, fetch_image
 
 MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
@@ -62,7 +63,11 @@ def get_hf_prompt_tokens(model_name, content, image_url):
     placeholder = "<|image_1|> "
     prompt = f"{placeholder}{content}"
 
-    images = [fetch_image(image_url)]
+    image = fetch_image(image_url)
+    # Unwrap MediaWithBytes if present
+    if isinstance(image, MediaWithBytes):
+        image = image.media
+    images = [image]
 
     inputs = processor(prompt, images, return_tensors="pt")
     return inputs.input_ids.shape[1]
diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py
index fef118a93c6cb..4a619fd303ca9 100644
--- a/vllm/multimodal/base.py
+++ b/vllm/multimodal/base.py
@@ -2,12 +2,40 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Generic, TypeVar
 
+import numpy as np
+
 _T = TypeVar("_T")
 
 
+@dataclass
+class MediaWithBytes(Generic[_T]):
+    """
+    Wrapper that couples a media object with its original encoded bytes.
+
+    This ensures the raw bytes and media object remain synchronized,
+    preventing cache corruption from in-place modifications.
+
+    The wrapper delegates attribute access to the underlying media object,
+    making it behave transparently like the wrapped type (e.g., PIL.Image).
+    """
+
+    media: _T
+    original_bytes: bytes
+
+    def __array__(self, *args, **kwargs) -> np.ndarray:
+        """Allow np.array(obj) to return np.array(obj.media)."""
+        return np.array(self.media, *args, **kwargs)
+
+    def __getattr__(self, name: str):
+        """Delegate attribute access to the underlying media object."""
+        # This is only called when the attribute is not found on self
+        return getattr(self.media, name)
+
+
 class MediaIO(ABC, Generic[_T]):
     @abstractmethod
     def load_bytes(self, data: bytes) -> _T:
diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py
index d0dcbb25fcce8..cc50322fed902 100644
--- a/vllm/multimodal/hasher.py
+++ b/vllm/multimodal/hasher.py
@@ -12,6 +12,8 @@
 from PIL import Image
 
 from vllm.logger import init_logger
 
+from .base import MediaWithBytes
+
 logger = init_logger(__name__)
 
@@ -31,14 +33,26 @@ class MultiModalHasher:
             if Image.ExifTags.Base.ImageID in exif and isinstance(
                 exif[Image.ExifTags.Base.ImageID], uuid.UUID
             ):
-                # If the image has exif ImageID tag, use that
                 return (exif[Image.ExifTags.Base.ImageID].bytes,)
+
             data = {"mode": obj.mode, "data": np.asarray(obj)}
-            if obj.palette is not None:
-                data["palette"] = obj.palette.palette
-                if obj.palette.rawmode is not None:
-                    data["palette_rawmode"] = obj.palette.rawmode
+            palette = obj.palette
+            if palette is not None:
+                data["palette"] = palette.palette
+                if palette.rawmode is not None:
+                    data["palette_rawmode"] = palette.rawmode
+            return cls.iter_item_to_bytes("image", data)
+
+        if isinstance(obj, MediaWithBytes) and isinstance(obj.media, Image.Image):
+            exif = obj.media.getexif()
+            if Image.ExifTags.Base.ImageID in exif and isinstance(
+                exif[Image.ExifTags.Base.ImageID], uuid.UUID
+            ):
+                return (exif[Image.ExifTags.Base.ImageID].bytes,)
+
+            return cls.iter_item_to_bytes("image", obj.original_bytes)
+
 
         if isinstance(obj, torch.Tensor):
             tensor_obj: torch.Tensor = obj.cpu()
             tensor_dtype = tensor_obj.dtype
diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py
index 21e8bef97a787..789421e9e0c3b 100644
--- a/vllm/multimodal/image.py
+++ b/vllm/multimodal/image.py
@@ -8,7 +8,7 @@ import pybase64
 import torch
 from PIL import Image
 
-from .base import MediaIO
+from .base import MediaIO, MediaWithBytes
 
 
 def rescale_image_size(
@@ -74,8 +74,12 @@ class ImageMediaIO(MediaIO[Image.Image]):
         )
         self.rgba_background_color = rgba_bg
 
-    def _convert_image_mode(self, image: Image.Image) -> Image.Image:
+    def _convert_image_mode(
+        self, image: Image.Image | MediaWithBytes[Image.Image]
+    ) -> Image.Image:
         """Convert image mode with custom background color."""
+        if isinstance(image, MediaWithBytes):
+            image = image.media
         if image.mode == self.image_mode:
             return image
         elif image.mode == "RGBA" and self.image_mode == "RGB":
@@ -83,18 +87,18 @@ class ImageMediaIO(MediaIO[Image.Image]):
         else:
             return convert_image_mode(image, self.image_mode)
 
-    def load_bytes(self, data: bytes) -> Image.Image:
+    def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]:
         image = Image.open(BytesIO(data))
-        image.load()
-        return self._convert_image_mode(image)
+        return MediaWithBytes(self._convert_image_mode(image), data)
 
-    def load_base64(self, media_type: str, data: str) -> Image.Image:
+    def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]:
         return self.load_bytes(pybase64.b64decode(data, validate=True))
 
-    def load_file(self, filepath: Path) -> Image.Image:
-        image = Image.open(filepath)
-        image.load()
-        return self._convert_image_mode(image)
+    def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]:
+        with open(filepath, "rb") as f:
+            data = f.read()
+        image = Image.open(BytesIO(data))
+        return MediaWithBytes(self._convert_image_mode(image), data)
 
     def encode_base64(
         self,
diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py
index 810f29072a0fe..0d3b8289e4e12 100644
--- a/vllm/multimodal/parse.py
+++ b/vllm/multimodal/parse.py
@@ -23,6 +23,7 @@ from vllm.utils.collection_utils import is_list_of
 from vllm.utils.import_utils import LazyLoader
 
 from .audio import AudioResampler
+from .base import MediaWithBytes
 from .inputs import (
     AudioItem,
     HfAudioItem,
@@ -84,6 +85,12 @@ class ModalityDataItems(ABC, Generic[_T, _I]):
         """Get all data items."""
         return [self.get(idx) for idx in range(self.get_count())]
 
+    def get_item_for_hash(self, index: int) -> object:
+        return self.get(index)
+
+    def get_all_items_for_hash(self) -> list[object]:
+        return [self.get_item_for_hash(idx) for idx in range(self.get_count())]
+
     @abstractmethod
     def get_processor_data(self) -> Mapping[str, object]:
         """Get the data to pass to the HF processor."""
@@ -98,10 +105,18 @@ class ModalityDataItems(ABC, Generic[_T, _I]):
 class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
     """Base class for data items that are arranged in a list."""
 
+    def _unwrap(self, item: _T | MediaWithBytes[_T]) -> _T:
+        """Extract media from wrapper if present."""
+        return item.media if isinstance(item, MediaWithBytes) else item
+
     def get_count(self) -> int:
         return len(self.data)
 
     def get(self, index: int) -> _T:
+        return self._unwrap(self.data[index])
+
+    def get_item_for_hash(self, index: int) -> _T | MediaWithBytes[_T]:
+        # Return raw item for hashing (preserves original_bytes if present)
         return self.data[index]
 
     def get_processor_data(self) -> Mapping[str, object]:
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index f241e79cfa7cb..0390773783961 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -1684,7 +1684,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
 
         # For None entries, compute a hash; otherwise, use provided ID.
         computed: list[str] = []
-        for i, item in enumerate(items):
+        for i, item in enumerate(items.get_all_items_for_hash()):
             item_uuid = mm_uuids_per_modality[i]
 
             # NOTE: Even if a item_uuid is provided, we still compute a
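-- 
Reviewer note (illustrative, not part of the patch): a minimal sketch of the
change end to end. `ImageMediaIO.load_bytes()` now returns a
`MediaWithBytes[Image.Image]` that carries the original encoded bytes alongside
the decoded image, so the hasher can hash the small byte payload instead of
serializing the full pixel array, while attribute access still behaves like a
plain PIL image. The sample file path and the no-argument `ImageMediaIO()`
construction are assumptions made for this example.

    from vllm.multimodal.base import MediaWithBytes
    from vllm.multimodal.image import ImageMediaIO

    image_io = ImageMediaIO()  # assumption: default constructor arguments

    with open("example.jpg", "rb") as f:  # hypothetical sample image
        raw = f.read()

    image = image_io.load_bytes(raw)
    assert isinstance(image, MediaWithBytes)
    # Raw encoded bytes are preserved for the preprocessing-cache hash
    assert image.original_bytes == raw
    # Attribute access is delegated to the wrapped PIL image via __getattr__
    print(image.size, image.mode)

Hashing `image.original_bytes` (typically kilobytes of encoded JPEG/PNG) is far
cheaper than hashing the decoded pixel array from `np.asarray(image)`
(potentially megabytes), which is where the cache speedup comes from.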