# vllm/vllm/multimodal/image.py

from functools import lru_cache

import torch
from PIL import Image
from transformers.image_processing_base import BatchFeature

from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_image_processor
from vllm.utils import is_list_of

from .base import MultiModalData, MultiModalInputs, MultiModalPlugin

logger = init_logger(__name__)

# Cache the image processor so that repeated lookups for the same model
# do not re-instantiate it.
cached_get_image_processor = lru_cache(get_image_processor)


class ImagePlugin(MultiModalPlugin):
    """Plugin for image data."""

    def get_data_key(self) -> str:
        return "image"

    def _get_hf_image_processor(self, model_config: ModelConfig):
        mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None
                               else model_config.mm_processor_kwargs)
        # We don't explicitly check kwarg overrides to the HF class
        # since the automodel just takes kwargs, so we can't inspect it
        return cached_get_image_processor(
            model_config.model,
            trust_remote_code=model_config.trust_remote_code,
            **mm_processor_kwargs)
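    # Hedged example (assumed, not in the original file): an override such as
    # ModelConfig(..., mm_processor_kwargs={"num_crops": 4}) would be passed
    # straight through to the HF image processor's constructor above.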

    def _default_input_mapper(
        self,
        ctx: InputContext,
        data: MultiModalData[object],
    ) -> MultiModalInputs:
        model_config = ctx.model_config

        # Processed by input processor
        if isinstance(data, BatchFeature):
            return MultiModalInputs(data.data)

        # PIL image
        if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
            image_processor = self._get_hf_image_processor(model_config)
            if image_processor is None:
                raise RuntimeError("No HuggingFace processor is available "
                                   "to process the image object")
            try:
                batch_data = image_processor \
                    .preprocess(data, return_tensors="pt") \
                    .data
            except Exception:
                logger.error("Failed to process image (%s)", data)
                raise

            return MultiModalInputs(batch_data)

        # Image embedding
        elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
            return MultiModalInputs({"image_embeds": data})

        raise TypeError(f"Invalid image type: {type(data)}")
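    # Note (assumption based on the base class in .base): the mapping
    # returned above is later combined across requests via
    # MultiModalInputs.batch() and fed to the model's forward() as kwargs.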

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        # Conservative default cap on the number of image feature tokens;
        # individual models can register a tighter limit with the registry.
        return 3000
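

# A minimal usage sketch (not part of the upstream module), showing how the
# default mapper turns a PIL image into model-ready tensors. The checkpoint
# name and the ModelConfig arguments are assumptions for illustration and may
# differ across vLLM revisions; `_default_input_mapper` is called directly
# here only for demonstration.
if __name__ == "__main__":
    image = Image.new("RGB", (336, 336))  # blank test image

    model_config = ModelConfig(
        model="llava-hf/llava-1.5-7b-hf",  # assumed HF checkpoint
        tokenizer="llava-hf/llava-1.5-7b-hf",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    ctx = InputContext(model_config)

    plugin = ImagePlugin()
    inputs = plugin._default_input_mapper(ctx, image)
    print(sorted(inputs.keys()))  # e.g. ["pixel_values"]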