mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 18:44:30 +08:00
[Multimodal][Core] Optimize multimodal preprocessing cache by hashing image bytes instead of pixel values (#29621)
Signed-off-by: Rahul Steiger <rasteiger@ethz.ch> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
68ffbca7e4
commit
60c3d413af
@ -59,6 +59,7 @@ from vllm.distributed import (
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.multimodal.base import MediaWithBytes
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
@ -1389,7 +1390,11 @@ class LocalAssetServer:
|
||||
return f"{self.base_url}/{name}"
|
||||
|
||||
def get_image_asset(self, name: str) -> Image.Image:
|
||||
return fetch_image(self.url_for(name))
|
||||
image = fetch_image(self.url_for(name))
|
||||
# Unwrap MediaWithBytes if present
|
||||
if isinstance(image, MediaWithBytes):
|
||||
image = image.media
|
||||
return image
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
||||
@ -8,6 +8,7 @@ import pytest
|
||||
import pytest_asyncio
|
||||
from transformers import AutoProcessor
|
||||
|
||||
from vllm.multimodal.base import MediaWithBytes
|
||||
from vllm.multimodal.utils import encode_image_base64, fetch_image
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
@ -111,7 +112,11 @@ def get_hf_prompt_tokens(model_name, content, image_url):
|
||||
"content": f"{placeholder}{content}",
|
||||
}
|
||||
]
|
||||
images = [fetch_image(image_url)]
|
||||
image = fetch_image(image_url)
|
||||
# Unwrap MediaWithBytes if present
|
||||
if isinstance(image, MediaWithBytes):
|
||||
image = image.media
|
||||
images = [image]
|
||||
|
||||
prompt = processor.tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
|
||||
@ -9,6 +9,7 @@ from transformers import AutoProcessor
|
||||
|
||||
from tests.utils import VLLM_PATH, RemoteOpenAIServer
|
||||
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
|
||||
from vllm.multimodal.base import MediaWithBytes
|
||||
from vllm.multimodal.utils import encode_image_base64, fetch_image
|
||||
|
||||
MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
|
||||
@ -62,7 +63,11 @@ def get_hf_prompt_tokens(model_name, content, image_url):
|
||||
|
||||
placeholder = "<|image_1|> "
|
||||
prompt = f"{placeholder}{content}"
|
||||
images = [fetch_image(image_url)]
|
||||
image = fetch_image(image_url)
|
||||
# Unwrap MediaWithBytes if present
|
||||
if isinstance(image, MediaWithBytes):
|
||||
image = image.media
|
||||
images = [image]
|
||||
inputs = processor(prompt, images, return_tensors="pt")
|
||||
return inputs.input_ids.shape[1]
|
||||
|
||||
|
||||
@ -2,12 +2,40 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MediaWithBytes(Generic[_T]):
|
||||
"""
|
||||
Wrapper that couples a media object with its original encoded bytes.
|
||||
|
||||
This ensures the raw bytes and media object remain synchronized,
|
||||
preventing cache corruption from in-place modifications.
|
||||
|
||||
The wrapper delegates attribute access to the underlying media object,
|
||||
making it behave transparently like the wrapped type (e.g., PIL.Image).
|
||||
"""
|
||||
|
||||
media: _T
|
||||
original_bytes: bytes
|
||||
|
||||
def __array__(self, *args, **kwargs) -> np.ndarray:
|
||||
"""Allow np.array(obj) to return np.array(obj.media)."""
|
||||
return np.array(self.media, *args, **kwargs)
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
"""Delegate attribute access to the underlying media object."""
|
||||
# This is only called when the attribute is not found on self
|
||||
return getattr(self.media, name)
|
||||
|
||||
|
||||
class MediaIO(ABC, Generic[_T]):
|
||||
@abstractmethod
|
||||
def load_bytes(self, data: bytes) -> _T:
|
||||
|
||||
@ -12,6 +12,8 @@ from PIL import Image
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .base import MediaWithBytes
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -31,14 +33,26 @@ class MultiModalHasher:
|
||||
if Image.ExifTags.Base.ImageID in exif and isinstance(
|
||||
exif[Image.ExifTags.Base.ImageID], uuid.UUID
|
||||
):
|
||||
# If the image has exif ImageID tag, use that
|
||||
return (exif[Image.ExifTags.Base.ImageID].bytes,)
|
||||
|
||||
data = {"mode": obj.mode, "data": np.asarray(obj)}
|
||||
if obj.palette is not None:
|
||||
data["palette"] = obj.palette.palette
|
||||
if obj.palette.rawmode is not None:
|
||||
data["palette_rawmode"] = obj.palette.rawmode
|
||||
palette = obj.palette
|
||||
if palette is not None:
|
||||
data["palette"] = palette.palette
|
||||
if palette.rawmode is not None:
|
||||
data["palette_rawmode"] = palette.rawmode
|
||||
|
||||
return cls.iter_item_to_bytes("image", data)
|
||||
|
||||
if isinstance(obj, MediaWithBytes) and isinstance(obj.media, Image.Image):
|
||||
exif = obj.media.getexif()
|
||||
if Image.ExifTags.Base.ImageID in exif and isinstance(
|
||||
exif[Image.ExifTags.Base.ImageID], uuid.UUID
|
||||
):
|
||||
return (exif[Image.ExifTags.Base.ImageID].bytes,)
|
||||
|
||||
return cls.iter_item_to_bytes("image", obj.original_bytes)
|
||||
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor_obj: torch.Tensor = obj.cpu()
|
||||
tensor_dtype = tensor_obj.dtype
|
||||
|
||||
@ -8,7 +8,7 @@ import pybase64
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from .base import MediaIO
|
||||
from .base import MediaIO, MediaWithBytes
|
||||
|
||||
|
||||
def rescale_image_size(
|
||||
@ -74,8 +74,12 @@ class ImageMediaIO(MediaIO[Image.Image]):
|
||||
)
|
||||
self.rgba_background_color = rgba_bg
|
||||
|
||||
def _convert_image_mode(self, image: Image.Image) -> Image.Image:
|
||||
def _convert_image_mode(
|
||||
self, image: Image.Image | MediaWithBytes[Image.Image]
|
||||
) -> Image.Image:
|
||||
"""Convert image mode with custom background color."""
|
||||
if isinstance(image, MediaWithBytes):
|
||||
image = image.media
|
||||
if image.mode == self.image_mode:
|
||||
return image
|
||||
elif image.mode == "RGBA" and self.image_mode == "RGB":
|
||||
@ -83,18 +87,18 @@ class ImageMediaIO(MediaIO[Image.Image]):
|
||||
else:
|
||||
return convert_image_mode(image, self.image_mode)
|
||||
|
||||
def load_bytes(self, data: bytes) -> Image.Image:
|
||||
def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]:
|
||||
image = Image.open(BytesIO(data))
|
||||
image.load()
|
||||
return self._convert_image_mode(image)
|
||||
return MediaWithBytes(self._convert_image_mode(image), data)
|
||||
|
||||
def load_base64(self, media_type: str, data: str) -> Image.Image:
|
||||
def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]:
|
||||
return self.load_bytes(pybase64.b64decode(data, validate=True))
|
||||
|
||||
def load_file(self, filepath: Path) -> Image.Image:
|
||||
image = Image.open(filepath)
|
||||
image.load()
|
||||
return self._convert_image_mode(image)
|
||||
def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]:
|
||||
with open(filepath, "rb") as f:
|
||||
data = f.read()
|
||||
image = Image.open(BytesIO(data))
|
||||
return MediaWithBytes(self._convert_image_mode(image), data)
|
||||
|
||||
def encode_base64(
|
||||
self,
|
||||
|
||||
@ -23,6 +23,7 @@ from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
|
||||
from .audio import AudioResampler
|
||||
from .base import MediaWithBytes
|
||||
from .inputs import (
|
||||
AudioItem,
|
||||
HfAudioItem,
|
||||
@ -84,6 +85,12 @@ class ModalityDataItems(ABC, Generic[_T, _I]):
|
||||
"""Get all data items."""
|
||||
return [self.get(idx) for idx in range(self.get_count())]
|
||||
|
||||
def get_item_for_hash(self, index: int) -> object:
|
||||
return self.get(index)
|
||||
|
||||
def get_all_items_for_hash(self) -> list[object]:
|
||||
return [self.get_item_for_hash(idx) for idx in range(self.get_count())]
|
||||
|
||||
@abstractmethod
|
||||
def get_processor_data(self) -> Mapping[str, object]:
|
||||
"""Get the data to pass to the HF processor."""
|
||||
@ -98,10 +105,18 @@ class ModalityDataItems(ABC, Generic[_T, _I]):
|
||||
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
|
||||
"""Base class for data items that are arranged in a list."""
|
||||
|
||||
def _unwrap(self, item: _T | MediaWithBytes[_T]) -> _T:
|
||||
"""Extract media from wrapper if present."""
|
||||
return item.media if isinstance(item, MediaWithBytes) else item
|
||||
|
||||
def get_count(self) -> int:
|
||||
return len(self.data)
|
||||
|
||||
def get(self, index: int) -> _T:
|
||||
return self._unwrap(self.data[index])
|
||||
|
||||
def get_item_for_hash(self, index: int) -> _T | MediaWithBytes[_T]:
|
||||
# Return raw item for hashing (preserves original_bytes if present)
|
||||
return self.data[index]
|
||||
|
||||
def get_processor_data(self) -> Mapping[str, object]:
|
||||
|
||||
@ -1684,7 +1684,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
|
||||
# For None entries, compute a hash; otherwise, use provided ID.
|
||||
computed: list[str] = []
|
||||
for i, item in enumerate(items):
|
||||
for i, item in enumerate(items.get_all_items_for_hash()):
|
||||
item_uuid = mm_uuids_per_modality[i]
|
||||
|
||||
# NOTE: Even if a item_uuid is provided, we still compute a
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user