mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 10:06:03 +08:00
480 lines
16 KiB
Python
480 lines
16 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import asyncio
|
|
import atexit
|
|
from collections.abc import Iterable
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from itertools import groupby
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
|
|
from urllib.parse import ParseResult, urlparse
|
|
from urllib.request import url2pathname
|
|
|
|
import numpy as np
|
|
import numpy.typing as npt
|
|
import torch
|
|
from PIL import Image, UnidentifiedImageError
|
|
|
|
import vllm.envs as envs
|
|
from vllm.connections import HTTPConnection, global_http_connection
|
|
from vllm.utils.jsontree import json_map_leaves
|
|
|
|
from .audio import AudioMediaIO
|
|
from .base import MediaIO
|
|
from .image import ImageEmbeddingMediaIO, ImageMediaIO
|
|
from .video import VideoMediaIO
|
|
|
|
_M = TypeVar("_M")

if not TYPE_CHECKING:
    # At runtime these names appear in this module only in annotations, so
    # `Any` placeholders stand in for the real types (presumably to avoid an
    # import cycle with `.inputs`; verify before changing).
    BatchedTensorInputs = MultiModalKwargsItem = Any
    MultiModalKwargsItems = MultiModalPlaceholderDict = Any
else:
    from .inputs import (BatchedTensorInputs, MultiModalKwargsItem,
                         MultiModalKwargsItems, MultiModalPlaceholderDict)

# Shared executor used by the async loaders to decode media off the event
# loop; it is shut down when the interpreter exits.
global_thread_pool = ThreadPoolExecutor(
    max_workers=envs.VLLM_MEDIA_LOADING_THREAD_COUNT)
atexit.register(global_thread_pool.shutdown)
|
|
|
|
|
|
class MediaConnector:
    """Fetch and decode multimodal inputs (audio, images, videos) from URLs.

    Supported URL schemes:

    - ``http``/``https``: downloaded through ``connection``. If
      ``allowed_media_domains`` is non-empty, the URL hostname must be in it.
    - ``data``: only base64-encoded payloads are supported.
    - ``file``: only allowed when ``allowed_local_media_path`` is set. The
      resolved file must live underneath that directory.
    """

    def __init__(
        self,
        media_io_kwargs: Optional[dict[str, dict[str, Any]]] = None,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
        allowed_media_domains: Optional[list[str]] = None,
    ) -> None:
        """
        Args:
            media_io_kwargs: Additional args passed to process media
                inputs, keyed by modalities. For example,
                to set num_frames for video, set
                `--media-io-kwargs '{"video":{"num_frames":40}}'`
            connection: HTTP connection client to download media contents.
            allowed_local_media_path: A local directory to load media files
                from. If empty, `file:` URLs are rejected.
            allowed_media_domains: Hostnames that HTTP(S) media URLs must
                belong to. If `None` or empty, any hostname is accepted.

        Raises:
            ValueError: If `allowed_local_media_path` is set but does not
                exist or is not a directory.
        """
        super().__init__()

        self.media_io_kwargs: dict[str, dict[str, Any]] = (
            media_io_kwargs if media_io_kwargs else {})
        self.connection = connection

        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist.")
            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory.")
        else:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_
        if allowed_media_domains is None:
            allowed_media_domains = []
        self.allowed_media_domains = allowed_media_domains

    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Decode a ``data:<media_type>;base64,<payload>`` URL.

        Raises:
            ValueError: If the URL lacks the ``,`` payload separator or the
                ``;`` encoding separator.
            NotImplementedError: If the encoding is not ``base64``.
        """
        # `partition` (instead of tuple-unpacking `split`) lets a malformed
        # URL produce a descriptive error rather than an unpacking error.
        data_spec, comma, data = url_spec.path.partition(",")
        media_type, semicolon, data_type = data_spec.partition(";")
        if not comma or not semicolon:
            raise ValueError("Invalid data URL: expected the form "
                             "'data:<media_type>;base64,<data>'.")

        if data_type != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:  # type: ignore[type-var]
        """Load a ``file:`` URL located under `allowed_local_media_path`.

        Raises:
            RuntimeError: If no local media directory was allowed.
            ValueError: If the resolved file is not inside that directory.
        """
        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError("Cannot load local files without "
                               "`--allowed-local-media-path`.")

        filepath = Path(url2pathname(url_spec.path))
        # The file path is resolved (collapsing `..` and symlinks), so the
        # allowed root must be resolved as well. Otherwise a relative or
        # symlinked `--allowed-local-media-path` would never match.
        allowed_root = allowed_local_media_path.resolve()
        if allowed_root not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path` {allowed_local_media_path}.")

        return media_io.load_file(filepath)

    def _assert_url_in_allowed_media_domains(self,
                                             url_spec: ParseResult) -> None:
        """Reject HTTP(S) URLs whose hostname is not allow-listed.

        The check is skipped entirely when `allowed_media_domains` is empty.

        Raises:
            ValueError: If the hostname is not in `allowed_media_domains`.
        """
        if self.allowed_media_domains and url_spec.hostname not in \
            self.allowed_media_domains:
            raise ValueError(
                f"The URL must be from one of the allowed domains: "
                f"{self.allowed_media_domains}. Input URL domain: "
                f"{url_spec.hostname}")

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:  # type: ignore[type-var]
        """Load media from an HTTP(S), ``data:`` or ``file:`` URL.

        Args:
            url: The URL to load.
            media_io: Decoder for the downloaded/embedded/local content.
            fetch_timeout: Timeout forwarded to the HTTP connection.

        Raises:
            ValueError: For unsupported schemes, disallowed domains, files
                outside the allowed directory, or malformed data URLs.
        """
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = connection.get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )

            return media_io.load_bytes(data)

        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        """Async variant of `load_from_url`.

        HTTP downloads are awaited on the connection. Decoding and local
        file/data loading run on `global_thread_pool` so they do not block
        the event loop.
        """
        url_spec = urlparse(url)
        loop = asyncio.get_running_loop()

        if url_spec.scheme.startswith("http"):
            self._assert_url_in_allowed_media_domains(url_spec)

            connection = self.connection
            data = await connection.async_get_bytes(
                url,
                timeout=fetch_timeout,
                allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
            )
            future = loop.run_in_executor(global_thread_pool,
                                          media_io.load_bytes, data)
            return await future

        if url_spec.scheme == "data":
            future = loop.run_in_executor(global_thread_pool,
                                          self._load_data_url, url_spec,
                                          media_io)
            return await future

        if url_spec.scheme == "file":
            future = loop.run_in_executor(global_thread_pool,
                                          self._load_file_url, url_spec,
                                          media_io)
            return await future

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Load audio from a URL.

        Uses the ``"audio"`` entry of `media_io_kwargs` and the
        `VLLM_AUDIO_FETCH_TIMEOUT` fetch timeout.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Asynchronously fetch audio from a URL.

        Uses the ``"audio"`` entry of `media_io_kwargs` and the
        `VLLM_AUDIO_FETCH_TIMEOUT` fetch timeout.
        """
        audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from an HTTP(S), base64 data, or file URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If the image cannot be identified by PIL (converted
                from `UnidentifiedImageError`) or the URL is rejected.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))

        try:
            return self.load_from_url(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from an HTTP(S), base64 data, or
        file URL.

        By default, the image is converted into RGB format.

        Raises:
            ValueError: If the image cannot be identified by PIL (converted
                from `UnidentifiedImageError`) or the URL is rejected.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))

        try:
            return await self.load_from_url_async(
                image_url,
                image_io,
                fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
            )
        except UnidentifiedImageError as e:
            # convert to ValueError to be properly caught upstream
            raise ValueError(str(e)) from e

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Load video from an HTTP(S), base64 data, or file URL.

        `image_mode` (RGB by default) and the ``"image"`` entry of
        `media_io_kwargs` configure the `ImageMediaIO` passed to
        `VideoMediaIO`. The ``"video"`` entry configures `VideoMediaIO`.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))
        video_io = VideoMediaIO(image_io,
                                **self.media_io_kwargs.get("video", {}))

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """
        Asynchronously load video from an HTTP(S), base64 data, or file URL.

        `image_mode` (RGB by default) and the ``"image"`` entry of
        `media_io_kwargs` configure the `ImageMediaIO` passed to
        `VideoMediaIO`. The ``"video"`` entry configures `VideoMediaIO`.
        """
        image_io = ImageMediaIO(image_mode=image_mode,
                                **self.media_io_kwargs.get("image", {}))
        video_io = VideoMediaIO(image_io,
                                **self.media_io_kwargs.get("video", {}))

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    def fetch_image_embedding(
        self,
        data: str,
    ) -> torch.Tensor:
        """
        Decode an image embedding from base64-encoded `data`.

        Note that `data` is the base64 payload itself, not a URL.
        """
        image_embedding_io = ImageEmbeddingMediaIO()

        return image_embedding_io.load_base64("", data)
|
|
|
|
|
|
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
) -> str:
    """Return `audio` (sampled at `sampling_rate`) encoded as base64.

    Encoding is delegated to a default-configured `AudioMediaIO`.
    """
    payload = (audio, sampling_rate)
    return AudioMediaIO().encode_base64(payload)
|
|
|
|
|
|
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Serialize a pillow image to a base64 string.

    The image goes through an `ImageMediaIO` configured with `image_mode`
    (RGB by default) and is written in the given `format` (JPEG by default).
    """
    encoder = ImageMediaIO(image_mode=image_mode)
    encoded = encoder.encode_base64(image, image_format=format)
    return encoded
|
|
|
|
|
|
def encode_video_base64(frames: npt.NDArray) -> str:
    """Encode video `frames` as a base64 string.

    Uses a `VideoMediaIO` wrapping a default-configured `ImageMediaIO`.
    """
    video_io = VideoMediaIO(ImageMediaIO())
    return video_io.encode_base64(frames)
|
|
|
|
|
|
def argsort_mm_positions(
        mm_positions: MultiModalPlaceholderDict) -> list[tuple[str, int]]:
    """
    Compute the order of all placeholders in `mm_positions` by `offset`.

    Placeholders are ordered by ascending `offset` (their starting index in
    the input sequence). Ties keep the dictionary's iteration order, because
    the sort is stable.

    Returns:
        A list of `(modality, idx)` pairs; each one addresses an item as
        `mm_positions[modality][idx]`.
    """
    entries = []
    for modality, placeholders in mm_positions.items():
        for idx, placeholder in enumerate(placeholders):
            entries.append((placeholder.offset, modality, idx))

    # Sort on the offset only, so equal offsets retain insertion order.
    entries.sort(key=lambda entry: entry[0])

    return [(modality, idx) for _, modality, idx in entries]
|
|
|
|
|
|
def group_mm_kwargs_by_modality(
    mm_kwargs: list[MultiModalKwargsItem],
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
    merge_by_field_config: Optional[bool] = None,
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
    """Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
    modality together and batch each group into keyword arguments.

    Only *consecutive* items are grouped (via `itertools.groupby`), so one
    modality can appear in several yielded groups if items are interleaved.

    Args:
        mm_kwargs: List of `MultiModalKwargsItem`.
        device: The device to place the grouped tensors on.
        pin_memory: Whether to pin memory for faster host-to-device transfer.
        merge_by_field_config: Required.
            - If `True`, each group is batched via
              `MultiModalKwargsItems.get_data` into a plain dict, and its
              tensor leaves are moved to `device` when one is given.
            - If `False`, the legacy `MultiModalKwargs.batch` /
              `MultiModalKwargs.as_kwargs` path is used.

    Yields:
        A tuple `(modality, num_items, grouped_kwargs)`.

    Raises:
        RuntimeError: If `merge_by_field_config` is not provided.
    """
    if merge_by_field_config is None:
        raise RuntimeError(
            "`group_mm_kwargs_by_modality` now requires "
            "`merge_by_field_config` arg, please update your model runner "
            "according to https://github.com/vllm-project/vllm/pull/25676.")

    from vllm.multimodal.inputs import MultiModalKwargs, MultiModalKwargsItems

    for modality, items in groupby(mm_kwargs, key=lambda item: item.modality):
        items_lst = list(items)

        # TODO: Deprecate `merge_by_field_config` once
        # we have migrated all in-tree models
        if merge_by_field_config:
            mm_kwargs_group: BatchedTensorInputs = dict(
                MultiModalKwargsItems.from_seq(items_lst).get_data(
                    pin_memory=pin_memory))

            if device is not None:
                # Move only tensor leaves; any other leaf value is passed
                # through unchanged instead of failing on a missing `.to`.
                mm_kwargs_group = json_map_leaves(
                    lambda x: x.to(device=device)
                    if isinstance(x, torch.Tensor) else x,
                    mm_kwargs_group,
                )
        else:
            mm_kwargs_group = MultiModalKwargs.as_kwargs(
                MultiModalKwargs.batch(
                    [
                        MultiModalKwargsItems.from_seq([item]).get_data()
                        for item in items_lst
                    ],
                    pin_memory=pin_memory,
                ),
                device=device,
            )

        yield modality, len(items_lst), mm_kwargs_group
|
|
|
|
|
|
def fetch_audio(
    audio_url: str,
    audio_io_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[np.ndarray, Union[int, float]]:
    """
    Fetch audio from a URL using a freshly constructed `MediaConnector`.

    Args:
        audio_url: URL of the audio file to fetch.
        audio_io_kwargs: Additional kwargs passed to handle audio IO. Falsy
            values (`None` or `{}`) mean "use the defaults".
    """
    if audio_io_kwargs:
        media_io_kwargs = {"audio": audio_io_kwargs}
    else:
        media_io_kwargs = None

    connector = MediaConnector(media_io_kwargs=media_io_kwargs)
    return connector.fetch_audio(audio_url)
|
|
|
|
|
|
def fetch_image(
    image_url: str,
    image_io_kwargs: Optional[dict[str, Any]] = None,
) -> Image.Image:
    """
    Fetch an image from a URL using a freshly constructed `MediaConnector`.

    Args:
        image_url: URL of the image file to fetch.
        image_io_kwargs: Additional kwargs passed to handle image IO. Falsy
            values (`None` or `{}`) mean "use the defaults".
    """
    if image_io_kwargs:
        media_io_kwargs = {"image": image_io_kwargs}
    else:
        media_io_kwargs = None

    connector = MediaConnector(media_io_kwargs=media_io_kwargs)
    return connector.fetch_image(image_url)
|
|
|
|
|
|
def fetch_video(
    video_url: str,
    video_io_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Fetch a video from a URL using a freshly constructed `MediaConnector`.

    Args:
        video_url: URL of the video file to fetch.
        video_io_kwargs: Additional kwargs passed to handle video IO. Falsy
            values (`None` or `{}`) mean "use the defaults".
    """
    if video_io_kwargs:
        media_io_kwargs = {"video": video_io_kwargs}
    else:
        media_io_kwargs = None

    connector = MediaConnector(media_io_kwargs=media_io_kwargs)
    return connector.fetch_video(video_url)
|