vllm/vllm/multimodal/utils.py
Russell Bryant 32335c8b34 Add option to restrict media domains (#25783)
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-27 23:32:55 -07:00

496 lines
16 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import atexit
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from urllib.parse import ParseResult, urlparse
from urllib.request import url2pathname
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image, UnidentifiedImageError
from typing_extensions import deprecated
import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection
from vllm.utils.jsontree import json_map_leaves
from .audio import AudioMediaIO
from .base import MediaIO
from .image import ImageEmbeddingMediaIO, ImageMediaIO
from .video import VideoMediaIO
# Type variable for the object a `MediaIO` loader produces; lets the
# `MediaConnector.load_from_url*` helpers return the same type as the
# `MediaIO[_M]` they are given.
_M = TypeVar("_M")

if TYPE_CHECKING:
    # Real types, imported only for static type checking.
    # NOTE(review): presumably deferred to avoid an import cycle with
    # `.inputs` — confirm.
    from .inputs import (BatchedTensorInputs, MultiModalKwargsItem,
                         MultiModalKwargsItems, MultiModalPlaceholderDict)
else:
    # Runtime placeholders. Within this file these module-level names are
    # only used in annotations (`group_mm_kwargs_by_modality` imports the
    # real `MultiModalKwargsItems` locally), so `Any` is sufficient.
    BatchedTensorInputs = Any
    MultiModalKwargsItem = Any
    MultiModalKwargsItems = Any
    MultiModalPlaceholderDict = Any

# Shared pool used by `MediaConnector.load_from_url_async` to run media
# loading (`load_bytes`, data-URL and file-URL loading) off the event loop.
# Created at import time; size comes from VLLM_MEDIA_LOADING_THREAD_COUNT.
global_thread_pool = ThreadPoolExecutor(
    max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT)
# Shut the pool down when the interpreter exits.
atexit.register(global_thread_pool.shutdown)
class MediaConnector:
    """Load media from HTTP(S), base64 `data:` and local `file:` URLs.

    HTTP(S) downloads can be restricted to a set of hostnames, and local
    files can only be read from inside an explicitly allowed directory.
    """

    def __init__(
        self,
        media_io_kwargs: Optional[dict[str, dict[str, Any]]] = None,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
        allowed_media_domains: Optional[list[str]] = None,
    ) -> None:
        """
        Args:
            media_io_kwargs: Additional args passed to process media
                             inputs, keyed by modalities. For example,
                             to set num_frames for video, set
                             `--media-io-kwargs '{"video":{"num_frames":40}}'`
            connection: HTTP connection client to download media contents.
            allowed_local_media_path: A local directory to load media files
                                      from. An empty string disables
                                      `file:` URLs. The directory is stored
                                      fully resolved (absolute, symlinks
                                      followed).
            allowed_media_domains: Hostnames that HTTP(S) media URLs may
                                   point to, compared case-insensitively.
                                   `None` or an empty list allows any
                                   hostname.

        Raises:
            ValueError: If `allowed_local_media_path` does not exist or is
                not a directory.
        """
        super().__init__()

        self.media_io_kwargs: dict[str, dict[
            str, Any]] = media_io_kwargs if media_io_kwargs else {}
        self.connection = connection

        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist.")
            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory.")

            # `_load_file_url` compares against the parents of the
            # *resolved* file path. The allowed directory must be resolved
            # too, or relative and symlinked directories never match.
            allowed_local_media_path_ = allowed_local_media_path_.resolve()
        else:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_

        if allowed_media_domains is None:
            allowed_media_domains = []
        self.allowed_media_domains = allowed_media_domains

    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Decode a `data:<media_type>[;<param>...];base64,<data>` URL.

        Raises:
            ValueError: If the URL has no "," before the data, or no ";"
                in its header.
            NotImplementedError: If the last header parameter is not
                "base64".
        """
        data_spec, sep, data = url_spec.path.partition(",")
        if not sep:
            raise ValueError("Invalid data URL: missing ',' between the "
                             "media type and the data.")

        # Per RFC 2397, the optional ";base64" marker is the *last*
        # parameter, after any media-type parameters such as ";charset=".
        media_type, *params = data_spec.split(";")
        if not params:
            raise ValueError("Invalid data URL: expected the form "
                             "`data:<media_type>;base64,<data>`.")
        if params[-1] != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Load a `file:` URL that lies inside `allowed_local_media_path`.

        Raises:
            RuntimeError: If no local media directory was configured.
            ValueError: If the resolved file path is not inside the allowed
                directory.
        """
        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError("Cannot load local files without "
                               "`--allowed-local-media-path`.")

        filepath = Path(url2pathname(url_spec.path))
        # Resolve ".." components and symlinks so the path cannot escape
        # the allowed directory.
        if allowed_local_media_path not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path` {allowed_local_media_path}.")

        return media_io.load_file(filepath)

    def _assert_url_in_allowed_media_domains(self,
                                             url_spec: ParseResult) -> None:
        """Raise if domains are restricted and the URL host is not allowed.

        `urlparse` lower-cases `hostname`, and host names are
        case-insensitive, so the configured domains are lower-cased before
        comparing.

        NOTE(review): only the URL given here is checked. Whether
        `self.connection` follows redirects to other hosts is not visible
        from this file — confirm.

        Raises:
            ValueError: If the hostname is not in `allowed_media_domains`.
        """
        if not self.allowed_media_domains:
            return

        allowed = {domain.lower() for domain in self.allowed_media_domains}
        if url_spec.hostname not in allowed:
            raise ValueError(
                f"The URL must be from one of the allowed domains: "
                f"{self.allowed_media_domains}. Input URL domain: "
                f"{url_spec.hostname}")

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:  # type: ignore[type-var]
        """Load media from `url` with `media_io`, dispatching on the scheme.

        - Schemes starting with "http": the hostname is checked against
          `allowed_media_domains`, then the bytes are downloaded with
          `self.connection` and decoded by `media_io.load_bytes`.
        - "data": base64 data URLs only.
        - "file": files inside `allowed_local_media_path` only.

        Raises:
            ValueError: For any other scheme, and for the validation
                failures of the individual loaders.
        """
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = connection.get_bytes(url, timeout=fetch_timeout)

            return media_io.load_bytes(data)

        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Async counterpart of `load_from_url`.

        The HTTP download is awaited on `self.connection`. Decoding and
        data/file URL loading run in `global_thread_pool`, so they do not
        block the event loop.
        """
        url_spec = urlparse(url)
        loop = asyncio.get_running_loop()

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = await connection.async_get_bytes(url, timeout=fetch_timeout)
            future = loop.run_in_executor(global_thread_pool,
                                          media_io.load_bytes, data)
            return await future

        if url_spec.scheme == "data":
            future = loop.run_in_executor(global_thread_pool,
                                          self._load_data_url, url_spec,
                                          media_io)
            return await future

        if url_spec.scheme == "file":
            future = loop.run_in_executor(global_thread_pool,
                                          self._load_file_url, url_spec,
                                          media_io)
            return await future

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Load audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Asynchronously fetch audio from a URL.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from an HTTP, base64 data or file URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))

        try:
            return self.load_from_url(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from an HTTP, base64 data or file URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))

        try:
            return await self.load_from_url_async(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video from an HTTP, base64 data or file URL.

        `image_mode` (default "RGB") is passed to the per-frame image IO.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))
        video_io = VideoMediaIO(image_io,
                                **self.media_io_kwargs.get("video", {}))

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Asynchronously load video from an HTTP, base64 data or file URL.

        `image_mode` (default "RGB") is passed to the per-frame image IO.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))
        video_io = VideoMediaIO(image_io,
                                **self.media_io_kwargs.get("video", {}))

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Decode an image embedding from a base64-encoded string.

        `data` is the raw base64 payload, not a URL.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
) -> str:
    """Serialize an audio waveform and its sampling rate to base64.

    Args:
        audio: The audio samples.
        sampling_rate: Sampling rate that accompanies `audio`.

    Returns:
        The string produced by `AudioMediaIO.encode_base64`.
    """
    payload = (audio, sampling_rate)
    return AudioMediaIO().encode_base64(payload)
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """Serialize a PIL image to a base64 string.

    Args:
        image: The image to encode.
        image_mode: Mode passed to `ImageMediaIO` (default "RGB"), which
            the image is converted to before encoding.
        format: Image file format used for encoding (default "JPEG").

    Returns:
        The base64-encoded image.
    """
    encoder = ImageMediaIO(image_mode=image_mode)
    encoded = encoder.encode_base64(image, image_format=format)
    return encoded
def encode_video_base64(frames: npt.NDArray) -> str:
    """Serialize video frames to base64 via `VideoMediaIO`.

    The frames are encoded with a default-configured `ImageMediaIO`.
    """
    return VideoMediaIO(ImageMediaIO()).encode_base64(frames)
def argsort_mm_positions(
        mm_positions: MultiModalPlaceholderDict) -> list[tuple[str, int]]:
    """
    Compute the order of all placeholders in `mm_positions` by `offset`
    (their starting index in the input sequence), ascending.

    The sort is stable: placeholders with equal offsets keep the dict's
    modality order, then their index within the modality.

    Returns:
        A list of `(modality, idx)` pairs; each pair addresses one item as
        `mm_positions[modality][idx]`.
    """
    entries = []
    for modality, placeholders in mm_positions.items():
        for idx, placeholder in enumerate(placeholders):
            entries.append((modality, idx, placeholder))

    # Key on `offset` only, so ties fall back to insertion order.
    entries.sort(key=lambda entry: entry[2].offset)

    return [entry[:2] for entry in entries]
# Temporary backward compatibility for plugins that define their own model runner
@deprecated("`group_mm_inputs_by_modality` is superseded by "
            "`group_mm_kwargs_by_modality` and will be removed in v0.13. "
            "Please use `group_mm_kwargs_by_modality` instead.")
def group_mm_inputs_by_modality(
    mm_inputs: list[MultiModalKwargsItems]
) -> list[list[MultiModalKwargsItems]]:
    """Split `mm_inputs` into runs of adjacent single-modality inputs.

    Consecutive inputs that each hold exactly one (and the same) modality
    share a group. An input with several modalities always forms its own
    group. An input with no modality raises `AssertionError`.
    """
    if not mm_inputs:
        return []

    def _group_key(mm_input: MultiModalKwargsItems) -> Union[str, int]:
        num_modalities = len(mm_input)
        if num_modalities == 1:
            return next(iter(mm_input.keys()))
        if num_modalities > 1:
            # The object's id is a key no neighbouring input shares, so
            # multi-modality inputs are never merged with others.
            return id(mm_input)
        raise AssertionError("This line should be unreachable.")

    return [
        list(members) for _, members in groupby(mm_inputs, key=_group_key)
    ]
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
    merge_by_field_config: bool = False,
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
    """Batch adjacent `MultiModalKwargsItem`s that share a modality.

    Only *consecutive* items are grouped (via `itertools.groupby`), so a
    modality that reappears later in `mm_kwargs` starts a new group.

    Args:
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.
        merge_by_field_config: If True, the whole group is merged with
            `MultiModalKwargsItems.from_seq(...).get_data`, and each leaf is
            moved to `device` when one is given. Otherwise, per-item data
            is combined with `MultiModalKwargs.batch` and converted by
            `MultiModalKwargs.as_kwargs`.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`.
    """
    from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems

    for modality, group in groupby(mm_kwargs, key=lambda item: item.modality):
        batch = list(group)

        # TODO: Enable `merge_by_field_config` for all models
        # to avoid creating an extra batch dimension (except for fields
        # that are meant to be stacked anyway).
        # We will also need to update each model to remove `flatten_bn`.
        if not merge_by_field_config:
            per_item_data = [
                MultiModalKwargsItems.from_seq([item]).get_data()
                for item in batch
            ]
            grouped = MultiModalKwargs.as_kwargs(
                MultiModalKwargs.batch(per_item_data, pin_memory=pin_memory),
                device=device,
            )
        else:
            grouped: BatchedTensorInputs = dict(
                MultiModalKwargsItems.from_seq(batch).get_data(
                    pin_memory=pin_memory))

            if device is not None:
                grouped = json_map_leaves(
                    lambda leaf: leaf.to(device=device),
                    grouped,
                )

        yield modality, len(batch), grouped
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[np.ndarray, Union[int, float]]:
    """Fetch audio through a one-off `MediaConnector`.

    The connector uses its defaults. Without an allowed local media path,
    `file:` URLs are rejected, and no domain restriction applies.

    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO.
    """
    if audio_io_kwargs:
        media_io_kwargs = {"audio": audio_io_kwargs}
    else:
        media_io_kwargs = None

    connector = MediaConnector(media_io_kwargs=media_io_kwargs)
    return connector.fetch_audio(audio_url)
def fetch_image(
    image_url: str,
    image_io_kwargs: Optional[dict[str, Any]] = None,
) -> Image.Image:
    """Fetch an image through a one-off `MediaConnector`.

    The connector uses its defaults. Without an allowed local media path,
    `file:` URLs are rejected, and no domain restriction applies.

    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO.
    """
    if image_io_kwargs:
        media_io_kwargs = {"image": image_io_kwargs}
    else:
        media_io_kwargs = None

    connector = MediaConnector(media_io_kwargs=media_io_kwargs)
    return connector.fetch_image(image_url)
def fetch_video(
    video_url: str,
    video_io_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """Fetch a video through a one-off `MediaConnector`.

    The connector uses its defaults. Without an allowed local media path,
    `file:` URLs are rejected, and no domain restriction applies.

    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO.
    """
    if video_io_kwargs:
        media_io_kwargs = {"video": video_io_kwargs}
    else:
        media_io_kwargs = None

    connector = MediaConnector(media_io_kwargs=media_io_kwargs)
    return connector.fetch_video(video_url)