vllm/vllm/multimodal/utils.py
Cyrus Leung 7af553ea30
[Misc] Abstract the logic for reading and writing media content (#11527)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
2024-12-27 19:21:23 +08:00

436 lines
14 KiB
Python

from functools import lru_cache
from pathlib import Path
from typing import Optional, TypeVar, Union
from urllib.parse import ParseResult, urlparse
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image
import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from .audio import AudioMediaIO
from .base import MediaIO
from .image import ImageMediaIO
from .inputs import PlaceholderRange
from .video import VideoMediaIO
logger = init_logger(__name__)
cached_get_tokenizer = lru_cache(get_tokenizer)
_M = TypeVar("_M")
class MediaConnector:
def __init__(
self,
connection: HTTPConnection = global_http_connection,
*,
allowed_local_media_path: str = "",
) -> None:
super().__init__()
self.connection = connection
if allowed_local_media_path:
allowed_local_media_path_ = Path(allowed_local_media_path)
if not allowed_local_media_path_.exists():
raise ValueError(
"Invalid `--allowed-local-media-path`: The path "
f"{allowed_local_media_path_} does not exist.")
if not allowed_local_media_path_.is_dir():
raise ValueError(
"Invalid `--allowed-local-media-path`: The path "
f"{allowed_local_media_path_} must be a directory.")
else:
allowed_local_media_path_ = None
self.allowed_local_media_path = allowed_local_media_path_
def _load_data_url(
self,
url_spec: ParseResult,
media_io: MediaIO[_M],
) -> _M:
data_spec, data = url_spec.path.split(",", 1)
media_type, data_type = data_spec.split(";", 1)
if data_type != "base64":
msg = "Only base64 data URLs are supported for now."
raise NotImplementedError(msg)
return media_io.load_base64(media_type, data)
def _load_file_url(
self,
url_spec: ParseResult,
media_io: MediaIO[_M],
) -> _M:
allowed_local_media_path = self.allowed_local_media_path
if allowed_local_media_path is None:
raise RuntimeError("Cannot load local files without "
"`--allowed-local-media-path`.")
filepath = Path(url_spec.path)
if allowed_local_media_path not in filepath.resolve().parents:
raise ValueError(
f"The file path {filepath} must be a subpath "
f"of `--allowed-local-media-path` {allowed_local_media_path}.")
return media_io.load_file(filepath)
def load_from_url(
self,
url: str,
media_io: MediaIO[_M],
*,
fetch_timeout: Optional[int] = None,
) -> _M:
url_spec = urlparse(url)
if url_spec.scheme.startswith("http"):
connection = self.connection
data = connection.get_bytes(url, timeout=fetch_timeout)
return media_io.load_bytes(data)
if url_spec.scheme == "data":
return self._load_data_url(url_spec, media_io)
if url_spec.scheme == "file":
return self._load_file_url(url_spec, media_io)
msg = "The URL must be either a HTTP, data or file URL."
raise ValueError(msg)
async def load_from_url_async(
self,
url: str,
media_io: MediaIO[_M],
*,
fetch_timeout: Optional[int] = None,
) -> _M:
url_spec = urlparse(url)
if url_spec.scheme.startswith("http"):
connection = self.connection
data = await connection.async_get_bytes(url, timeout=fetch_timeout)
return media_io.load_bytes(data)
if url_spec.scheme == "data":
return self._load_data_url(url_spec, media_io)
if url_spec.scheme == "file":
return self._load_file_url(url_spec, media_io)
msg = "The URL must be either a HTTP, data or file URL."
raise ValueError(msg)
def fetch_audio(
self,
audio_url: str,
) -> tuple[np.ndarray, Union[int, float]]:
"""
Load audio from a URL.
"""
audio_io = AudioMediaIO()
return self.load_from_url(
audio_url,
audio_io,
fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
)
async def fetch_audio_async(
self,
audio_url: str,
) -> tuple[np.ndarray, Union[int, float]]:
"""
Asynchronously fetch audio from a URL.
"""
audio_io = AudioMediaIO()
return await self.load_from_url_async(
audio_url,
audio_io,
fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
)
def fetch_image(
self,
image_url: str,
*,
image_mode: str = "RGB",
) -> Image.Image:
"""
Load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
image_io = ImageMediaIO(image_mode=image_mode)
return self.load_from_url(
image_url,
image_io,
fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
)
async def fetch_image_async(
self,
image_url: str,
*,
image_mode: str = "RGB",
) -> Image.Image:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
image_io = ImageMediaIO(image_mode=image_mode)
return await self.load_from_url_async(
image_url,
image_io,
fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
)
def fetch_video(
self,
video_url: str,
*,
image_mode: str = "RGB",
num_frames: int = 32,
) -> npt.NDArray:
"""
Load video from a HTTP or base64 data URL.
"""
image_io = ImageMediaIO(image_mode=image_mode)
video_io = VideoMediaIO(image_io, num_frames=num_frames)
return self.load_from_url(
video_url,
video_io,
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
)
async def fetch_video_async(
self,
video_url: str,
*,
image_mode: str = "RGB",
num_frames: int = 32,
) -> npt.NDArray:
"""
Asynchronously load video from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
image_io = ImageMediaIO(image_mode=image_mode)
video_io = VideoMediaIO(image_io, num_frames=num_frames)
return await self.load_from_url_async(
video_url,
video_io,
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
)
global_media_connector = MediaConnector()
"""The global :class:`MediaConnector` instance used by vLLM."""
fetch_audio = global_media_connector.fetch_audio
fetch_image = global_media_connector.fetch_image
fetch_video = global_media_connector.fetch_video
def encode_audio_base64(
audio: np.ndarray,
sampling_rate: int,
) -> str:
"""Encode audio as base64."""
audio_io = AudioMediaIO()
return audio_io.encode_base64((audio, sampling_rate))
def encode_image_base64(
image: Image.Image,
*,
image_mode: str = "RGB",
format: str = "JPEG",
) -> str:
"""
Encode a pillow image to base64 format.
By default, the image is converted into RGB format before being encoded.
"""
image_io = ImageMediaIO(image_mode=image_mode)
return image_io.encode_base64(image, image_format=format)
def encode_video_base64(frames: npt.NDArray) -> str:
image_io = ImageMediaIO()
video_io = VideoMediaIO(image_io)
return video_io.encode_base64(frames)
def resolve_visual_encoder_outputs(
encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
feature_sample_layers: Optional[list[int]],
post_layer_norm: Optional[torch.nn.LayerNorm],
max_possible_layers: int,
) -> torch.Tensor:
"""Given the outputs a visual encoder module that may correspond to the
output of the last layer, or a list of hidden states to be stacked,
handle post normalization and resolve it into a single output tensor.
Args:
encoder_outputs: Output of encoder's last layer or all hidden states.
feature_sample_layers: Optional layer indices to grab from the encoder
outputs; if provided, encoder outputs must be a list.
post_layer_norm: Post norm to apply to the output of the encoder.
max_possible_layers: Total layers in the fully loaded visual encoder.
"""
if feature_sample_layers is None:
if post_layer_norm is not None:
return post_layer_norm(encoder_outputs)
return encoder_outputs
# Get the hidden states corresponding to the layer indices.
# Negative values are relative to the full visual encoder,
# so offset them depending on how many layers were loaded.
# NOTE: this assumes that encoder_outputs contains a list
# of hidden states in the same order as the encoder layers
# that produced them.
offset = max_possible_layers - len(encoder_outputs)
hs_pool = [
encoder_outputs[layer_idx]
if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
for layer_idx in feature_sample_layers
]
# Apply post-norm on the final hidden state if we are using it
uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
if post_layer_norm is not None and uses_last_layer:
hs_pool[-1] = post_layer_norm(encoder_outputs)
return torch.cat(hs_pool, dim=-1)
# Utilities for input processors
_T = TypeVar("_T", str, int)
def repeat_and_pad_token(
token: _T,
*,
repeat_count: int = 1,
pad_token_left: Optional[_T] = None,
pad_token_right: Optional[_T] = None,
) -> list[_T]:
replacement = [token] * repeat_count
if pad_token_left is not None:
replacement = [pad_token_left] + replacement
if pad_token_right is not None:
replacement = replacement + [pad_token_right]
return replacement
def repeat_and_pad_placeholder_tokens(
tokenizer: AnyTokenizer,
prompt: Optional[str],
prompt_token_ids: list[int],
*,
placeholder_token_id: int,
repeat_count: Union[int, list[int]],
pad_token_left: Optional[int] = None,
pad_token_right: Optional[int] = None,
) -> tuple[Optional[str], list[int], list[PlaceholderRange]]:
if isinstance(repeat_count, int):
repeat_count = [repeat_count]
if prompt is None:
new_prompt = None
else:
placeholder_token_str = tokenizer.decode(placeholder_token_id)
pad_token_str_left = (None if pad_token_left is None else
tokenizer.decode(pad_token_left))
pad_token_str_right = (None if pad_token_right is None else
tokenizer.decode(pad_token_right))
placeholder_token_count = prompt.count(placeholder_token_str)
# This is an arbitrary number to distinguish between the two cases
if placeholder_token_count > 16:
logger.warning(
"Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"repeating %s tokens.", placeholder_token_str)
if placeholder_token_count < len(repeat_count):
logger.warning(
"The number of multi-modal placeholder tokens in the prompt "
"is less than the number of multi-modal inputs. Extra "
"placeholder tokens will be treated as plain text")
repeat_count = repeat_count[:placeholder_token_count]
prompt_parts = prompt.split(placeholder_token_str,
maxsplit=len(repeat_count))
new_prompt = ""
for i, repeat_count_item in enumerate(repeat_count):
replacement_str = "".join(
repeat_and_pad_token(
placeholder_token_str,
repeat_count=repeat_count_item,
pad_token_left=pad_token_str_left,
pad_token_right=pad_token_str_right,
))
# The image tokens are removed to be consistent with HuggingFace
new_prompt += prompt_parts[i] + replacement_str
new_prompt += prompt_parts[-1]
new_token_ids = list[int]()
placeholder_ranges = list[PlaceholderRange]()
placeholder_token_idx = 0
for i, token in enumerate(prompt_token_ids):
if token == placeholder_token_id:
replacement_ids = repeat_and_pad_token(
placeholder_token_id,
repeat_count=repeat_count[placeholder_token_idx],
pad_token_left=pad_token_left,
pad_token_right=pad_token_right,
)
placeholder_ranges.append({
"offset": len(new_token_ids),
"length": len(replacement_ids)
})
new_token_ids.extend(replacement_ids)
placeholder_token_idx += 1
# No need to further scan the list since we replaced all tokens
if placeholder_token_idx >= len(repeat_count):
new_token_ids.extend(prompt_token_ids[i + 1:])
break
else:
new_token_ids.append(token)
return new_prompt, new_token_ids, placeholder_ranges
def consecutive_placeholder_ranges(
num_items: int,
item_size: int,
initial_offset: int = 0) -> list[PlaceholderRange]:
"""Returns a list of consecutive PlaceholderRanges of a fixed size"""
return [
PlaceholderRange(offset=initial_offset + i * item_size,
length=item_size) for i in range(num_items)
]