mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-10 17:42:16 +08:00
[Multimodal][Core] Optimize multimodal preprocessing cache by hashing image bytes instead of pixel values (#29621)
Signed-off-by: Rahul Steiger <rasteiger@ethz.ch> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
68ffbca7e4
commit
60c3d413af
@ -59,6 +59,7 @@ from vllm.distributed import (
|
|||||||
)
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.logprobs import Logprob
|
from vllm.logprobs import Logprob
|
||||||
|
from vllm.multimodal.base import MediaWithBytes
|
||||||
from vllm.multimodal.utils import fetch_image
|
from vllm.multimodal.utils import fetch_image
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.sampling_params import BeamSearchParams
|
from vllm.sampling_params import BeamSearchParams
|
||||||
@ -1389,7 +1390,11 @@ class LocalAssetServer:
|
|||||||
return f"{self.base_url}/{name}"
|
return f"{self.base_url}/{name}"
|
||||||
|
|
||||||
def get_image_asset(self, name: str) -> Image.Image:
|
def get_image_asset(self, name: str) -> Image.Image:
|
||||||
return fetch_image(self.url_for(name))
|
image = fetch_image(self.url_for(name))
|
||||||
|
# Unwrap MediaWithBytes if present
|
||||||
|
if isinstance(image, MediaWithBytes):
|
||||||
|
image = image.media
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import pytest
|
|||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from transformers import AutoProcessor
|
from transformers import AutoProcessor
|
||||||
|
|
||||||
|
from vllm.multimodal.base import MediaWithBytes
|
||||||
from vllm.multimodal.utils import encode_image_base64, fetch_image
|
from vllm.multimodal.utils import encode_image_base64, fetch_image
|
||||||
|
|
||||||
from ...utils import RemoteOpenAIServer
|
from ...utils import RemoteOpenAIServer
|
||||||
@ -111,7 +112,11 @@ def get_hf_prompt_tokens(model_name, content, image_url):
|
|||||||
"content": f"{placeholder}{content}",
|
"content": f"{placeholder}{content}",
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
images = [fetch_image(image_url)]
|
image = fetch_image(image_url)
|
||||||
|
# Unwrap MediaWithBytes if present
|
||||||
|
if isinstance(image, MediaWithBytes):
|
||||||
|
image = image.media
|
||||||
|
images = [image]
|
||||||
|
|
||||||
prompt = processor.tokenizer.apply_chat_template(
|
prompt = processor.tokenizer.apply_chat_template(
|
||||||
messages, tokenize=False, add_generation_prompt=True
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from transformers import AutoProcessor
|
|||||||
|
|
||||||
from tests.utils import VLLM_PATH, RemoteOpenAIServer
|
from tests.utils import VLLM_PATH, RemoteOpenAIServer
|
||||||
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
|
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
|
||||||
|
from vllm.multimodal.base import MediaWithBytes
|
||||||
from vllm.multimodal.utils import encode_image_base64, fetch_image
|
from vllm.multimodal.utils import encode_image_base64, fetch_image
|
||||||
|
|
||||||
MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
|
MODEL_NAME = "TIGER-Lab/VLM2Vec-Full"
|
||||||
@ -62,7 +63,11 @@ def get_hf_prompt_tokens(model_name, content, image_url):
|
|||||||
|
|
||||||
placeholder = "<|image_1|> "
|
placeholder = "<|image_1|> "
|
||||||
prompt = f"{placeholder}{content}"
|
prompt = f"{placeholder}{content}"
|
||||||
images = [fetch_image(image_url)]
|
image = fetch_image(image_url)
|
||||||
|
# Unwrap MediaWithBytes if present
|
||||||
|
if isinstance(image, MediaWithBytes):
|
||||||
|
image = image.media
|
||||||
|
images = [image]
|
||||||
inputs = processor(prompt, images, return_tensors="pt")
|
inputs = processor(prompt, images, return_tensors="pt")
|
||||||
return inputs.input_ids.shape[1]
|
return inputs.input_ids.shape[1]
|
||||||
|
|
||||||
|
|||||||
@ -2,12 +2,40 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Generic, TypeVar
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
_T = TypeVar("_T")


@dataclass
class MediaWithBytes(Generic[_T]):
    """
    Wrapper that couples a media object with its original encoded bytes.

    This ensures the raw bytes and media object remain synchronized,
    preventing cache corruption from in-place modifications.

    The wrapper delegates attribute access to the underlying media object,
    making it behave transparently like the wrapped type (e.g., PIL.Image).

    Attribute delegation is safe on partially constructed instances: when
    the ``media`` field itself has not been set yet (as happens while
    ``copy`` / ``pickle`` rebuild an object), missing attributes raise
    ``AttributeError`` instead of recursing.
    """

    # The decoded media object (e.g. a PIL image).
    media: _T
    # The encoded bytes that `media` was loaded from.
    original_bytes: bytes

    def __array__(self, *args, **kwargs) -> np.ndarray:
        """Allow np.array(obj) to return np.array(obj.media)."""
        return np.array(self.media, *args, **kwargs)

    def __getattr__(self, name: str):
        """Delegate attribute access to the underlying media object.

        Raises:
            AttributeError: If ``name`` is ``media`` (i.e. the wrapped object
                has not been set on this instance yet), or if the wrapped
                object does not have the attribute either.
        """
        # This is only called when the attribute is not found on self.
        # If `media` itself is the missing attribute, the instance is bare:
        # copy/pickle create it via __new__ and probe `__setstate__` before
        # restoring the fields. Delegating through `self.media` in that state
        # would re-enter this method forever and raise RecursionError.
        if name == "media":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.media, name)
|
||||||
|
|
||||||
|
|
||||||
class MediaIO(ABC, Generic[_T]):
|
class MediaIO(ABC, Generic[_T]):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_bytes(self, data: bytes) -> _T:
|
def load_bytes(self, data: bytes) -> _T:
|
||||||
|
|||||||
@ -12,6 +12,8 @@ from PIL import Image
|
|||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
from .base import MediaWithBytes
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -31,14 +33,26 @@ class MultiModalHasher:
|
|||||||
if Image.ExifTags.Base.ImageID in exif and isinstance(
|
if Image.ExifTags.Base.ImageID in exif and isinstance(
|
||||||
exif[Image.ExifTags.Base.ImageID], uuid.UUID
|
exif[Image.ExifTags.Base.ImageID], uuid.UUID
|
||||||
):
|
):
|
||||||
# If the image has exif ImageID tag, use that
|
|
||||||
return (exif[Image.ExifTags.Base.ImageID].bytes,)
|
return (exif[Image.ExifTags.Base.ImageID].bytes,)
|
||||||
|
|
||||||
data = {"mode": obj.mode, "data": np.asarray(obj)}
|
data = {"mode": obj.mode, "data": np.asarray(obj)}
|
||||||
if obj.palette is not None:
|
palette = obj.palette
|
||||||
data["palette"] = obj.palette.palette
|
if palette is not None:
|
||||||
if obj.palette.rawmode is not None:
|
data["palette"] = palette.palette
|
||||||
data["palette_rawmode"] = obj.palette.rawmode
|
if palette.rawmode is not None:
|
||||||
|
data["palette_rawmode"] = palette.rawmode
|
||||||
|
|
||||||
return cls.iter_item_to_bytes("image", data)
|
return cls.iter_item_to_bytes("image", data)
|
||||||
|
|
||||||
|
if isinstance(obj, MediaWithBytes) and isinstance(obj.media, Image.Image):
|
||||||
|
exif = obj.media.getexif()
|
||||||
|
if Image.ExifTags.Base.ImageID in exif and isinstance(
|
||||||
|
exif[Image.ExifTags.Base.ImageID], uuid.UUID
|
||||||
|
):
|
||||||
|
return (exif[Image.ExifTags.Base.ImageID].bytes,)
|
||||||
|
|
||||||
|
return cls.iter_item_to_bytes("image", obj.original_bytes)
|
||||||
|
|
||||||
if isinstance(obj, torch.Tensor):
|
if isinstance(obj, torch.Tensor):
|
||||||
tensor_obj: torch.Tensor = obj.cpu()
|
tensor_obj: torch.Tensor = obj.cpu()
|
||||||
tensor_dtype = tensor_obj.dtype
|
tensor_dtype = tensor_obj.dtype
|
||||||
|
|||||||
@ -8,7 +8,7 @@ import pybase64
|
|||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from .base import MediaIO
|
from .base import MediaIO, MediaWithBytes
|
||||||
|
|
||||||
|
|
||||||
def rescale_image_size(
|
def rescale_image_size(
|
||||||
@ -74,8 +74,12 @@ class ImageMediaIO(MediaIO[Image.Image]):
|
|||||||
)
|
)
|
||||||
self.rgba_background_color = rgba_bg
|
self.rgba_background_color = rgba_bg
|
||||||
|
|
||||||
def _convert_image_mode(self, image: Image.Image) -> Image.Image:
|
def _convert_image_mode(
|
||||||
|
self, image: Image.Image | MediaWithBytes[Image.Image]
|
||||||
|
) -> Image.Image:
|
||||||
"""Convert image mode with custom background color."""
|
"""Convert image mode with custom background color."""
|
||||||
|
if isinstance(image, MediaWithBytes):
|
||||||
|
image = image.media
|
||||||
if image.mode == self.image_mode:
|
if image.mode == self.image_mode:
|
||||||
return image
|
return image
|
||||||
elif image.mode == "RGBA" and self.image_mode == "RGB":
|
elif image.mode == "RGBA" and self.image_mode == "RGB":
|
||||||
@ -83,18 +87,18 @@ class ImageMediaIO(MediaIO[Image.Image]):
|
|||||||
else:
|
else:
|
||||||
return convert_image_mode(image, self.image_mode)
|
return convert_image_mode(image, self.image_mode)
|
||||||
|
|
||||||
def load_bytes(self, data: bytes) -> Image.Image:
|
def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]:
|
||||||
image = Image.open(BytesIO(data))
|
image = Image.open(BytesIO(data))
|
||||||
image.load()
|
return MediaWithBytes(self._convert_image_mode(image), data)
|
||||||
return self._convert_image_mode(image)
|
|
||||||
|
|
||||||
def load_base64(self, media_type: str, data: str) -> Image.Image:
|
def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]:
|
||||||
return self.load_bytes(pybase64.b64decode(data, validate=True))
|
return self.load_bytes(pybase64.b64decode(data, validate=True))
|
||||||
|
|
||||||
def load_file(self, filepath: Path) -> Image.Image:
|
def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]:
|
||||||
image = Image.open(filepath)
|
with open(filepath, "rb") as f:
|
||||||
image.load()
|
data = f.read()
|
||||||
return self._convert_image_mode(image)
|
image = Image.open(BytesIO(data))
|
||||||
|
return MediaWithBytes(self._convert_image_mode(image), data)
|
||||||
|
|
||||||
def encode_base64(
|
def encode_base64(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -23,6 +23,7 @@ from vllm.utils.collection_utils import is_list_of
|
|||||||
from vllm.utils.import_utils import LazyLoader
|
from vllm.utils.import_utils import LazyLoader
|
||||||
|
|
||||||
from .audio import AudioResampler
|
from .audio import AudioResampler
|
||||||
|
from .base import MediaWithBytes
|
||||||
from .inputs import (
|
from .inputs import (
|
||||||
AudioItem,
|
AudioItem,
|
||||||
HfAudioItem,
|
HfAudioItem,
|
||||||
@ -84,6 +85,12 @@ class ModalityDataItems(ABC, Generic[_T, _I]):
|
|||||||
"""Get all data items."""
|
"""Get all data items."""
|
||||||
return [self.get(idx) for idx in range(self.get_count())]
|
return [self.get(idx) for idx in range(self.get_count())]
|
||||||
|
|
||||||
|
def get_item_for_hash(self, index: int) -> object:
|
||||||
|
return self.get(index)
|
||||||
|
|
||||||
|
def get_all_items_for_hash(self) -> list[object]:
|
||||||
|
return [self.get_item_for_hash(idx) for idx in range(self.get_count())]
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_processor_data(self) -> Mapping[str, object]:
|
def get_processor_data(self) -> Mapping[str, object]:
|
||||||
"""Get the data to pass to the HF processor."""
|
"""Get the data to pass to the HF processor."""
|
||||||
@ -98,10 +105,18 @@ class ModalityDataItems(ABC, Generic[_T, _I]):
|
|||||||
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
|
class ProcessorBatchItems(ModalityDataItems[Sequence[_T], _T]):
|
||||||
"""Base class for data items that are arranged in a list."""
|
"""Base class for data items that are arranged in a list."""
|
||||||
|
|
||||||
|
def _unwrap(self, item: _T | MediaWithBytes[_T]) -> _T:
|
||||||
|
"""Extract media from wrapper if present."""
|
||||||
|
return item.media if isinstance(item, MediaWithBytes) else item
|
||||||
|
|
||||||
def get_count(self) -> int:
|
def get_count(self) -> int:
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
|
|
||||||
def get(self, index: int) -> _T:
|
def get(self, index: int) -> _T:
|
||||||
|
return self._unwrap(self.data[index])
|
||||||
|
|
||||||
|
def get_item_for_hash(self, index: int) -> _T | MediaWithBytes[_T]:
|
||||||
|
# Return raw item for hashing (preserves original_bytes if present)
|
||||||
return self.data[index]
|
return self.data[index]
|
||||||
|
|
||||||
def get_processor_data(self) -> Mapping[str, object]:
|
def get_processor_data(self) -> Mapping[str, object]:
|
||||||
|
|||||||
@ -1684,7 +1684,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
|||||||
|
|
||||||
# For None entries, compute a hash; otherwise, use provided ID.
|
# For None entries, compute a hash; otherwise, use provided ID.
|
||||||
computed: list[str] = []
|
computed: list[str] = []
|
||||||
for i, item in enumerate(items):
|
for i, item in enumerate(items.get_all_items_for_hash()):
|
||||||
item_uuid = mm_uuids_per_modality[i]
|
item_uuid = mm_uuids_per_modality[i]
|
||||||
|
|
||||||
# NOTE: Even if a item_uuid is provided, we still compute a
|
# NOTE: Even if a item_uuid is provided, we still compute a
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user