[Misc] Abstract the logic for reading and writing media content (#11527)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2024-12-27 19:21:23 +08:00 committed by GitHub
parent 2c9b8ea2b0
commit 7af553ea30
10 changed files with 495 additions and 389 deletions
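At a high level, this commit replaces the loose fetch_* and get_and_parse_* helpers in vllm.multimodal.utils with a MediaConnector class that routes HTTP, data: and file:// URLs through per-modality MediaIO codecs. A minimal usage sketch of the new API (the URL and path below are illustrative, not taken from the diff):

from vllm.multimodal.utils import MediaConnector

connector = MediaConnector(allowed_local_media_path="/tmp/media")

# Each fetch_* method builds the matching MediaIO codec and delegates to
# load_from_url(), which dispatches on the URL scheme.
image = connector.fetch_image("https://example.com/cat.jpg")     # PIL.Image.Image
audio, sr = connector.fetch_audio("file:///tmp/media/clip.wav")  # ndarray, sample rate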

View File

@ -33,6 +33,7 @@ class MockModelConfig:
hf_config = MockHFConfig()
logits_processor_pattern = None
diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = ""
def get_diff_sampling_param(self):
return self.diff_sampling_param or {}

View File

@ -2,7 +2,6 @@ import warnings
from typing import Optional
import pytest
from PIL import Image
from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig
@ -91,10 +90,7 @@ def _assert_mm_data_is_image_input(
image_data = mm_data.get("image")
assert image_data is not None
if image_count == 1:
    assert isinstance(image_data, Image.Image)
else:
    assert isinstance(image_data, list) and len(image_data) == image_count

assert isinstance(image_data, list) and len(image_data) == image_count
def test_parse_chat_messages_single_image(

View File

@ -9,7 +9,7 @@ import pytest
from PIL import Image, ImageChops
from transformers import AutoConfig, AutoTokenizer
from vllm.multimodal.utils import (async_fetch_image, fetch_image,
from vllm.multimodal.utils import (MediaConnector,
repeat_and_pad_placeholder_tokens)
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
@ -23,7 +23,12 @@ TEST_IMAGE_URLS = [
@pytest.fixture(scope="module")
def url_images() -> Dict[str, Image.Image]:
return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS}
connector = MediaConnector()
return {
image_url: connector.fetch_image(image_url)
for image_url in TEST_IMAGE_URLS
}
def get_supported_suffixes() -> Tuple[str, ...]:
@ -43,8 +48,10 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_http(image_url: str):
image_sync = fetch_image(image_url)
image_async = await async_fetch_image(image_url)
connector = MediaConnector()
image_sync = connector.fetch_image(image_url)
image_async = await connector.fetch_image_async(image_url)
assert _image_equals(image_sync, image_async)
@ -53,6 +60,7 @@ async def test_fetch_image_http(image_url: str):
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
image_url: str, suffix: str):
connector = MediaConnector()
url_image = url_images[image_url]
try:
@ -75,48 +83,49 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
base64_image = base64.b64encode(f.read()).decode("utf-8")
data_url = f"data:{mime_type};base64,{base64_image}"
data_image_sync = fetch_image(data_url)
data_image_sync = connector.fetch_image(data_url)
if _image_equals(url_image, Image.open(f)):
assert _image_equals(url_image, data_image_sync)
else:
pass # Lossy format; only check that image can be opened
data_image_async = await async_fetch_image(data_url)
data_image_async = await connector.fetch_image_async(data_url)
assert _image_equals(data_image_sync, data_image_async)
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_local_files(image_url: str):
connector = MediaConnector()
with TemporaryDirectory() as temp_dir:
origin_image = fetch_image(image_url)
local_connector = MediaConnector(allowed_local_media_path=temp_dir)
origin_image = connector.fetch_image(image_url)
origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)),
quality=100,
icc_profile=origin_image.info.get('icc_profile'))
image_async = await async_fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)
image_sync = fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)
image_async = await local_connector.fetch_image_async(
f"file://{temp_dir}/{os.path.basename(image_url)}")
image_sync = local_connector.fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}")
# Check that the images are equal
assert not ImageChops.difference(image_sync, image_async).getbbox()
with pytest.raises(ValueError):
    await async_fetch_image(
        f"file://{temp_dir}/../{os.path.basename(image_url)}",
        allowed_local_media_path=temp_dir)
with pytest.raises(ValueError):
    await async_fetch_image(
        f"file://{temp_dir}/../{os.path.basename(image_url)}")
with pytest.raises(ValueError):
    fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}",
                allowed_local_media_path=temp_dir)
with pytest.raises(ValueError):
    fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}")

with pytest.raises(ValueError, match="must be a subpath"):
    await local_connector.fetch_image_async(
        f"file://{temp_dir}/../{os.path.basename(image_url)}")
with pytest.raises(RuntimeError, match="Cannot load local files"):
    await connector.fetch_image_async(
        f"file://{temp_dir}/../{os.path.basename(image_url)}")
with pytest.raises(ValueError, match="must be a subpath"):
    local_connector.fetch_image(
        f"file://{temp_dir}/../{os.path.basename(image_url)}")
with pytest.raises(RuntimeError, match="Cannot load local files"):
    connector.fetch_image(
        f"file://{temp_dir}/../{os.path.basename(image_url)}")
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
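The updated test distinguishes the two failure modes of the new API: a connector constructed with allowed_local_media_path rejects file URLs that escape that directory with ValueError, while the default connector refuses file URLs altogether with RuntimeError. A condensed sketch of the behaviour under test (paths are illustrative):

from vllm.multimodal.utils import MediaConnector

local_connector = MediaConnector(allowed_local_media_path="/tmp/images")
connector = MediaConnector()

local_connector.fetch_image("file:///tmp/images/a.png")  # OK: inside the allowed root
# local_connector.fetch_image("file:///tmp/images/../b.png")  # ValueError: must be a subpath
# connector.fetch_image("file:///tmp/images/a.png")           # RuntimeError: no allowed path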

View File

@ -21,12 +21,10 @@ class AudioAsset:
name: Literal["winning_call", "mary_had_lamb"]
@property
def audio_and_sample_rate(self) -> tuple[npt.NDArray, int]:
def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
s3_prefix=ASSET_DIR)
y, sr = librosa.load(audio_path, sr=None)
assert isinstance(sr, int)
return y, sr
return librosa.load(audio_path, sr=None)
@property
def url(self) -> str:
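The annotation changes from int to float because librosa.load with sr=None returns the file's native sampling rate, which librosa types as a float; the removed assert could therefore fail spuriously. A small illustration (asset name from this file, behaviour per librosa's load signature):

import librosa

y, sr = librosa.load("winning_call.ogg", sr=None)  # sr is the native rate, typed float
print(y.shape, sr)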

View File

@ -6,7 +6,7 @@ from collections import defaultdict, deque
from functools import lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)
Literal, Optional, Tuple, TypeVar, Union, cast)
import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
@ -23,6 +23,8 @@ from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@ -31,11 +33,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image,
async_get_and_parse_video,
get_and_parse_audio, get_and_parse_image,
get_and_parse_video)
from vllm.multimodal.utils import MediaConnector
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import print_warning_once
@ -368,14 +366,17 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._tokenizer = tokenizer
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
if model_config.multimodal_config else {})
self._consumed_items = {k: 0 for k in self._allowed_items}
self._items: List[_T] = []
self._items_by_modality = defaultdict[str, list[_T]](list)
@property
def model_config(self) -> ModelConfig:
return self._model_config
@property
def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path
@staticmethod
@lru_cache(maxsize=None)
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
@ -435,38 +436,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
else:
raise TypeError(f"Unknown modality: {modality}")
@staticmethod
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
mm_lists: Mapping[str, List[object]] = defaultdict(list)
# Merge all the multi-modal items
for single_mm_data in items:
for mm_key, mm_item in single_mm_data.items():
if isinstance(mm_item, list):
mm_lists[mm_key].extend(mm_item)
else:
mm_lists[mm_key].append(mm_item)
# Unpack any single item lists for models that don't expect multiple.
return {
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
for mm_key, mm_list in mm_lists.items()
}
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
"""
Add a multi-modal item to the current prompt and return the
placeholder string to use, if any.
"""
allowed_count = self._allowed_items.get(modality, 1)
current_count = self._consumed_items.get(modality, 0) + 1
current_count = len(self._items_by_modality[modality]) + 1
if current_count > allowed_count:
raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in "
"one request.")
self._consumed_items[modality] = current_count
self._items.append(item)
self._items_by_modality[modality].append(item)
return self._placeholder_str(modality, current_count)
@ -475,22 +457,26 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
raise NotImplementedError
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def all_mm_data(self) -> Optional[MultiModalDataDict]:
return self._combine(self._items) if self._items else None
if self._items_by_modality:
return dict(self._items_by_modality)
return None
def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self)
class AsyncMultiModalItemTracker(
BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
if self._items:
items = await asyncio.gather(*self._items)
return self._combine(items)
if self._items_by_modality:
return {
modality: await asyncio.gather(*items)
for modality, items in self._items_by_modality.items()
}
return None
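Tracking items per modality lets the async tracker gather each modality's awaitables independently and removes the need for the _combine unpacking step; every modality now maps to a list even when it holds a single item, which is also why the test earlier in this diff now always expects a list. Conceptually, all_mm_data() for a prompt with two images and one audio clip returns (a sketch, not code from the diff):

mm_data = {
    "image": [image_0, image_1],  # one list per modality, in prompt order
    "audio": [(audio_0, sr_0)],   # no longer unpacked when there is one item
}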
@ -522,7 +508,7 @@ class BaseMultiModalContentParser(ABC):
raise NotImplementedError
@abstractmethod
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
def parse_input_audio(self, input_audio: InputAudio) -> None:
raise NotImplementedError
@abstractmethod
@ -537,31 +523,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._tracker = tracker
self._connector = MediaConnector(
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str) -> None:
image = get_and_parse_image(image_url,
allowed_local_media_path=self._tracker.
_model_config.allowed_local_media_path)
image = self._connector.fetch_image(image_url)
placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None:
audio = get_and_parse_audio(audio_url)
audio = self._connector.fetch_audio(audio_url)
placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder)
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
    input_audio_data = input_audio.get("data","")
    input_audio_format = input_audio.get("format","")
    audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
    audio = get_and_parse_audio(audio_url)
    placeholder = self._tracker.add("audio", audio)
    self._add_placeholder(placeholder)

def parse_input_audio(self, input_audio: InputAudio) -> None:
    audio_data = input_audio.get("data", "")
    audio_format = input_audio.get("format", "")
    audio_url = f"data:audio/{audio_format};base64,{audio_data}"

    return self.parse_audio(audio_url)
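Both parsers now normalize input_audio parts into a base64 data URL and reuse the existing audio path instead of duplicating fetch logic. For example (payload values are illustrative):

input_audio = {"data": "UklGRi...", "format": "wav"}  # an InputAudio part
audio_url = f"data:audio/{input_audio['format']};base64,{input_audio['data']}"
# -> "data:audio/wav;base64,UklGRi...", which MediaConnector decodes
#    via AudioMediaIO.load_base64()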
def parse_video(self, video_url: str) -> None:
video = get_and_parse_video(video_url)
video = self._connector.fetch_video(video_url)
placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)
@ -573,33 +559,31 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
super().__init__()
self._tracker = tracker
self._connector = MediaConnector(
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str) -> None:
image_coro = async_get_and_parse_image(
image_url,
allowed_local_media_path=self._tracker._model_config.
allowed_local_media_path)
image_coro = self._connector.fetch_image_async(image_url)
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None:
audio_coro = async_get_and_parse_audio(audio_url)
audio_coro = self._connector.fetch_audio_async(audio_url)
placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder)
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
    input_audio_data = input_audio.get("data","")
    input_audio_format = input_audio.get("format","")
    audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
    audio_coro = async_get_and_parse_audio(audio_url)
    placeholder = self._tracker.add("audio", audio_coro)
    self._add_placeholder(placeholder)

def parse_input_audio(self, input_audio: InputAudio) -> None:
    audio_data = input_audio.get("data", "")
    audio_format = input_audio.get("format", "")
    audio_url = f"data:audio/{audio_format};base64,{audio_data}"

    return self.parse_audio(audio_url)
def parse_video(self, video_url: str) -> None:
video = async_get_and_parse_video(video_url)
video = self._connector.fetch_video_async(video_url)
placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder)
@ -695,10 +679,13 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[str,
Callable[[ChatCompletionContentPartParam],
Union[str, Dict[str,str]]]] = {
MM_PARSER_MAP: Dict[
str,
Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
"text":
lambda part: _TextParser(part).get("text", ""),
"image_url":
@ -715,8 +702,7 @@ MM_PARSER_MAP: Dict[str,
def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> Tuple[str,
Union[str, Dict[str, str]]]:
part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]:
"""
Parses a given multi-modal content part based on its type.
@ -783,7 +769,7 @@ def _parse_chat_message_content_parts(
*,
wrap_dicts: bool,
) -> List[ConversationMessage]:
content: List[Union[str, Dict[str, str]]] = []
content = list[_ContentPart]()
mm_parser = mm_tracker.create_parser()
@ -814,7 +800,7 @@ def _parse_chat_message_content_part(
mm_parser: BaseMultiModalContentParser,
*,
wrap_dicts: bool,
) -> Optional[Union[str, Dict[str, str]]]:
) -> Optional[_ContentPart]:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text": ...} and
@ -823,8 +809,7 @@ def _parse_chat_message_content_part(
with multimodal placeholders.
"""
if isinstance(part, str): # Handle plain text parts
text = _TextParser(part)
return text
return part
# Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part)
@ -855,7 +840,7 @@ def _parse_chat_message_content_part(
return {'type': 'audio'} if wrap_dicts else None
if part_type == "input_audio":
dict_content = cast(Dict[str, str], content)
dict_content = cast(InputAudio, content)
mm_parser.parse_input_audio(dict_content)
return {'type': 'audio'} if wrap_dicts else None

View File

@ -1,10 +1,14 @@
import base64
from io import BytesIO
from pathlib import Path
import numpy as np
import numpy.typing as npt
from vllm.inputs.registry import InputContext
from vllm.utils import PlaceholderModule
from .base import MultiModalPlugin
from .base import MediaIO, MultiModalPlugin
from .inputs import AudioItem, MultiModalData, MultiModalKwargs
try:
@ -12,6 +16,11 @@ try:
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
try:
import soundfile
except ImportError:
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
class AudioPlugin(MultiModalPlugin):
"""Plugin for audio data."""
@ -39,3 +48,28 @@ def resample_audio(
target_sr: float,
) -> npt.NDArray[np.floating]:
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
return librosa.load(BytesIO(data), sr=None)
def load_base64(
self,
media_type: str,
data: str,
) -> tuple[npt.NDArray, float]:
return self.load_bytes(base64.b64decode(data))
def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
return librosa.load(filepath, sr=None)
def encode_base64(self, media: tuple[npt.NDArray, float]) -> str:
audio, sr = media
with BytesIO() as buffer:
soundfile.write(buffer, audio, sr, format="WAV")
data = buffer.getvalue()
return base64.b64encode(data).decode('utf-8')
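AudioMediaIO gives a symmetric pair: encode_base64 renders the waveform as WAV, and load_base64 round-trips it. A minimal sketch, assuming librosa and soundfile are installed:

import numpy as np
from vllm.multimodal.audio import AudioMediaIO

audio_io = AudioMediaIO()
tone = np.sin(2 * np.pi * 440 * np.linspace(0, 1, 16000)).astype(np.float32)

b64 = audio_io.encode_base64((tone, 16000))           # WAV bytes, base64-encoded
decoded, sr = audio_io.load_base64("audio/wav", b64)  # decoded back via librosa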

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple,
Optional, Sequence, Tuple, Type, TypeVar, Union)
from torch import nn
@ -118,7 +119,7 @@ class MultiModalPlugin(ABC):
self,
model_config: "ModelConfig",
data: MultiModalData[Any],
mm_processor_kwargs: Optional[Dict[str, Any]],
mm_processor_kwargs: Optional[dict[str, Any]],
) -> MultiModalKwargs:
"""
Transform the data into a dictionary of model inputs using the
@ -254,10 +255,10 @@ class MultiModalPlaceholderMap:
"""
class IndexMap(NamedTuple):
src: List[int]
dest: List[int]
src: list[int]
dest: list[int]
src_ranges: List[range]
src_ranges: list[range]
"""
The indices of the multi-modal embeddings that will replace the
corresponding placeholder embeddings pointed to by ``dest_ranges``.
@ -268,7 +269,7 @@ class MultiModalPlaceholderMap:
The total number of flattened multi-modal embeddings.
"""
dest_ranges: List[range]
dest_ranges: list[range]
"""
The indices of the placeholder embeddings that will be replaced by the
multimodal embeddings.
@ -288,7 +289,7 @@ class MultiModalPlaceholderMap:
@classmethod
def from_seq_group(
cls, seq_group: "SequenceGroupMetadata", positions: range
) -> Tuple[Optional[MultiModalDataDict], Dict[str,
) -> Tuple[Optional[MultiModalDataDict], dict[str,
"MultiModalPlaceholderMap"]]:
"""
Returns the multi-modal items that intersect with the portion of a
@ -376,9 +377,9 @@ class MultiModalPlaceholderMap:
def append_items_from_seq_group(
self,
positions: range,
multi_modal_items: List[_T],
multi_modal_items: list[_T],
multi_modal_placeholders: Sequence[PlaceholderRange],
) -> List[_T]:
) -> list[_T]:
"""
Adds the multi-modal items that intersect ``positions`` to this
placeholder map and returns the intersecting items.
@ -454,3 +455,22 @@ class MultiModalPlaceholderMap:
return MultiModalPlaceholderMap.IndexMap(src=src_indices,
dest=dest_indices)
class MediaIO(ABC, Generic[_T]):
@abstractmethod
def load_bytes(self, data: bytes) -> _T:
raise NotImplementedError
@abstractmethod
def load_base64(self, media_type: str, data: str) -> _T:
"""
List of media types:
https://www.iana.org/assignments/media-types/media-types.xhtml
"""
raise NotImplementedError
@abstractmethod
def load_file(self, filepath: Path) -> _T:
raise NotImplementedError
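MediaIO is the extension point: MediaConnector only needs these three hooks, so supporting a new modality means implementing the interface and passing an instance to load_from_url. A hypothetical minimal implementation for UTF-8 text, purely to show the contract (not part of the commit):

import base64
from pathlib import Path

from vllm.multimodal.base import MediaIO

class TextMediaIO(MediaIO[str]):
    """Hypothetical MediaIO that decodes UTF-8 text payloads."""

    def load_bytes(self, data: bytes) -> str:
        return data.decode("utf-8")

    def load_base64(self, media_type: str, data: str) -> str:
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> str:
        return filepath.read_text("utf-8")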

View File

@ -1,4 +1,7 @@
import base64
from functools import lru_cache
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional
import torch
@ -9,7 +12,7 @@ from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_image_processor
from vllm.utils import is_list_of
from .base import MultiModalPlugin
from .base import MediaIO, MultiModalPlugin
from .inputs import ImageItem, MultiModalData, MultiModalKwargs
if TYPE_CHECKING:
@ -96,3 +99,39 @@ def rescale_image_size(image: Image.Image,
if transpose >= 0:
image = image.transpose(Image.Transpose(transpose))
return image
class ImageMediaIO(MediaIO[Image.Image]):
def __init__(self, *, image_mode: str = "RGB") -> None:
super().__init__()
self.image_mode = image_mode
def load_bytes(self, data: bytes) -> Image.Image:
image = Image.open(BytesIO(data))
image.load()
return image.convert(self.image_mode)
def load_base64(self, media_type: str, data: str) -> Image.Image:
return self.load_bytes(base64.b64decode(data))
def load_file(self, filepath: Path) -> Image.Image:
image = Image.open(filepath)
image.load()
return image.convert(self.image_mode)
def encode_base64(
self,
media: Image.Image,
*,
image_format: str = "JPEG",
) -> str:
image = media
with BytesIO() as buffer:
image = image.convert(self.image_mode)
image.save(buffer, image_format)
data = buffer.getvalue()
return base64.b64encode(data).decode('utf-8')
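ImageMediaIO centralizes what encode_image_base64 and the old module-level loaders each did on their own. An example round-trip (a sketch; JPEG is lossy, so compare metadata rather than pixels):

from PIL import Image
from vllm.multimodal.image import ImageMediaIO

image_io = ImageMediaIO(image_mode="RGB")
b64 = image_io.encode_base64(Image.new("RGB", (8, 8), "red"))
restored = image_io.load_base64("image/jpeg", b64)
assert restored.size == (8, 8)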

View File

@ -1,8 +1,7 @@
import base64
import os
from functools import lru_cache
from io import BytesIO
from typing import List, Optional, Tuple, TypeVar, Union
from pathlib import Path
from typing import Optional, TypeVar, Union
from urllib.parse import ParseResult, urlparse
import numpy as np
import numpy.typing as npt
@ -10,283 +9,246 @@ import torch
from PIL import Image
import vllm.envs as envs
from vllm.connections import global_http_connection
from vllm.connections import HTTPConnection, global_http_connection
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.utils import PlaceholderModule
from .inputs import MultiModalDataDict, PlaceholderRange
try:
import decord
except ImportError:
decord = PlaceholderModule("decord") # type: ignore[assignment]
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
try:
import soundfile
except ImportError:
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
from .audio import AudioMediaIO
from .base import MediaIO
from .image import ImageMediaIO
from .inputs import PlaceholderRange
from .video import VideoMediaIO
logger = init_logger(__name__)
cached_get_tokenizer = lru_cache(get_tokenizer)
def _load_image_from_bytes(b: bytes) -> Image.Image:
    image = Image.open(BytesIO(b))
    image.load()
    return image

def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool:
    # Get the common path
    common_path = os.path.commonpath([
        os.path.abspath(image_path),
        os.path.abspath(allowed_local_media_path)
    ])
    # Check if the common path is the same as allowed_local_media_path
    return common_path == os.path.abspath(allowed_local_media_path)

def _load_image_from_file(image_url: str,
                          allowed_local_media_path: str) -> Image.Image:
    if not allowed_local_media_path:
        raise ValueError("Invalid 'image_url': Cannot load local files without"
                         "'--allowed-local-media-path'.")
    if allowed_local_media_path:
        if not os.path.exists(allowed_local_media_path):
            raise ValueError(
                "Invalid '--allowed-local-media-path': "
                f"The path {allowed_local_media_path} does not exist.")
        if not os.path.isdir(allowed_local_media_path):
            raise ValueError(
                "Invalid '--allowed-local-media-path': "
                f"The path {allowed_local_media_path} must be a directory.")

    # Only split once and assume the second part is the image path
    _, image_path = image_url.split("file://", 1)
    if not _is_subpath(image_path, allowed_local_media_path):
        raise ValueError(
            f"Invalid 'image_url': The file path {image_path} must"
            " be a subpath of '--allowed-local-media-path'"
            f" '{allowed_local_media_path}'.")

    image = Image.open(image_path)
    image.load()
    return image

def _load_image_from_data_url(image_url: str) -> Image.Image:
    # Only split once and assume the second part is the base64 encoded image
    _, image_base64 = image_url.split(",", 1)
    return load_image_from_base64(image_base64)

def fetch_image(image_url: str,
                *,
                image_mode: str = "RGB",
                allowed_local_media_path: str = "") -> Image.Image:
    """
    Load a PIL image from a HTTP or base64 data URL.

    By default, the image is converted into RGB format.
    """
    if image_url.startswith('http'):
        image_raw = global_http_connection.get_bytes(
            image_url,
            timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )
        image = _load_image_from_bytes(image_raw)
    elif image_url.startswith('data:image'):
        image = _load_image_from_data_url(image_url)
    elif image_url.startswith('file://'):
        image = _load_image_from_file(image_url, allowed_local_media_path)
    else:
        raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
                         "with either 'data:image', 'file://' or 'http'.")

    return image.convert(image_mode)

async def async_fetch_image(image_url: str,
                            *,
                            image_mode: str = "RGB",
                            allowed_local_media_path: str = "") -> Image.Image:
    """
    Asynchronously load a PIL image from a HTTP or base64 data URL.

    By default, the image is converted into RGB format.
    """
    if image_url.startswith('http'):
        image_raw = await global_http_connection.async_get_bytes(
            image_url,
            timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )
        image = _load_image_from_bytes(image_raw)
    elif image_url.startswith('data:image'):
        image = _load_image_from_data_url(image_url)
    elif image_url.startswith('file://'):
        image = _load_image_from_file(image_url, allowed_local_media_path)
    else:
        raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
                         "with either 'data:image', 'file://' or 'http'.")

    return image.convert(image_mode)

def _load_video_from_bytes(b: bytes, num_frames: int = 32) -> npt.NDArray:
    video_path = BytesIO(b)
    vr = decord.VideoReader(video_path, num_threads=1)
    total_frame_num = len(vr)

    if total_frame_num > num_frames:
        uniform_sampled_frames = np.linspace(0,
                                             total_frame_num - 1,
                                             num_frames,
                                             dtype=int)
        frame_idx = uniform_sampled_frames.tolist()
    else:
        frame_idx = [i for i in range(0, total_frame_num)]
    frames = vr.get_batch(frame_idx).asnumpy()

    return frames

def _load_video_from_data_url(video_url: str) -> npt.NDArray:
    # Only split once and assume the second part is the base64 encoded video
    _, video_base64 = video_url.split(",", 1)

    if video_url.startswith("data:video/jpeg;"):
        return np.stack([
            np.array(load_image_from_base64(frame_base64))
            for frame_base64 in video_base64.split(",")
        ])

    return load_video_from_base64(video_base64)

def fetch_video(video_url: str, *, num_frames: int = 32) -> npt.NDArray:
    """
    Load video from a HTTP or base64 data URL.
    """
    if video_url.startswith('http') or video_url.startswith('https'):
        video_raw = global_http_connection.get_bytes(
            video_url,
            timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )
        video = _load_video_from_bytes(video_raw, num_frames)
    elif video_url.startswith('data:video'):
        video = _load_video_from_data_url(video_url)
    else:
        raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
                         "with either 'data:video' or 'http'.")
    return video

async def async_fetch_video(video_url: str,
                            *,
                            num_frames: int = 32) -> npt.NDArray:
    """
    Asynchronously load video from a HTTP or base64 data URL.

    By default, the image is converted into RGB format.
    """
    if video_url.startswith('http') or video_url.startswith('https'):
        video_raw = await global_http_connection.async_get_bytes(
            video_url,
            timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )
        video = _load_video_from_bytes(video_raw, num_frames)
    elif video_url.startswith('data:video'):
        video = _load_video_from_data_url(video_url)
    else:
        raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
                         "with either 'data:video' or 'http'.")
    return video

def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
    """
    Load audio from a URL.
    """
    if audio_url.startswith("http"):
        audio_bytes = global_http_connection.get_bytes(
            audio_url,
            timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )
    elif audio_url.startswith("data:audio"):
        _, audio_base64 = audio_url.split(",", 1)
        audio_bytes = base64.b64decode(audio_base64)
    else:
        raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
                         "with either 'data:audio' or 'http'.")

    return librosa.load(BytesIO(audio_bytes), sr=None)

async def async_fetch_audio(
        audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
    """
    Asynchronously fetch audio from a URL.
    """
    if audio_url.startswith("http"):
        audio_bytes = await global_http_connection.async_get_bytes(
            audio_url,
            timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )
    elif audio_url.startswith("data:audio"):
        _, audio_base64 = audio_url.split(",", 1)
        audio_bytes = base64.b64decode(audio_base64)
    else:
        raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
                         "with either 'data:audio' or 'http'.")

    return librosa.load(BytesIO(audio_bytes), sr=None)

def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
    audio, sr = fetch_audio(audio_url)
    return {"audio": (audio, sr)}

def get_and_parse_image(
        image_url: str,
        *,
        allowed_local_media_path: str = "") -> MultiModalDataDict:
    image = fetch_image(image_url,
                        allowed_local_media_path=allowed_local_media_path)
    return {"image": image}

def get_and_parse_video(video_url: str) -> MultiModalDataDict:
    video = fetch_video(video_url)
    return {"video": video}

async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
    audio, sr = await async_fetch_audio(audio_url)
    return {"audio": (audio, sr)}

async def async_get_and_parse_image(
        image_url: str,
        *,
        allowed_local_media_path: str = "") -> MultiModalDataDict:
    image = await async_fetch_image(
        image_url, allowed_local_media_path=allowed_local_media_path)
    return {"image": image}

async def async_get_and_parse_video(video_url: str) -> MultiModalDataDict:
    video = await async_fetch_video(video_url)
    return {"video": video}

_M = TypeVar("_M")

class MediaConnector:

    def __init__(
        self,
        connection: HTTPConnection = global_http_connection,
        *,
        allowed_local_media_path: str = "",
    ) -> None:
        super().__init__()

        self.connection = connection

        if allowed_local_media_path:
            allowed_local_media_path_ = Path(allowed_local_media_path)

            if not allowed_local_media_path_.exists():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} does not exist.")
            if not allowed_local_media_path_.is_dir():
                raise ValueError(
                    "Invalid `--allowed-local-media-path`: The path "
                    f"{allowed_local_media_path_} must be a directory.")
        else:
            allowed_local_media_path_ = None

        self.allowed_local_media_path = allowed_local_media_path_

    def _load_data_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        data_spec, data = url_spec.path.split(",", 1)
        media_type, data_type = data_spec.split(";", 1)

        if data_type != "base64":
            msg = "Only base64 data URLs are supported for now."
            raise NotImplementedError(msg)

        return media_io.load_base64(media_type, data)

    def _load_file_url(
        self,
        url_spec: ParseResult,
        media_io: MediaIO[_M],
    ) -> _M:
        allowed_local_media_path = self.allowed_local_media_path
        if allowed_local_media_path is None:
            raise RuntimeError("Cannot load local files without "
                               "`--allowed-local-media-path`.")

        filepath = Path(url_spec.path)
        if allowed_local_media_path not in filepath.resolve().parents:
            raise ValueError(
                f"The file path {filepath} must be a subpath "
                f"of `--allowed-local-media-path` {allowed_local_media_path}.")

        return media_io.load_file(filepath)

    def load_from_url(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            connection = self.connection
            data = connection.get_bytes(url, timeout=fetch_timeout)

            return media_io.load_bytes(data)

        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    async def load_from_url_async(
        self,
        url: str,
        media_io: MediaIO[_M],
        *,
        fetch_timeout: Optional[int] = None,
    ) -> _M:
        url_spec = urlparse(url)

        if url_spec.scheme.startswith("http"):
            connection = self.connection
            data = await connection.async_get_bytes(url, timeout=fetch_timeout)

            return media_io.load_bytes(data)

        if url_spec.scheme == "data":
            return self._load_data_url(url_spec, media_io)

        if url_spec.scheme == "file":
            return self._load_file_url(url_spec, media_io)

        msg = "The URL must be either a HTTP, data or file URL."
        raise ValueError(msg)

    def fetch_audio(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Load audio from a URL.
        """
        audio_io = AudioMediaIO()

        return self.load_from_url(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    async def fetch_audio_async(
        self,
        audio_url: str,
    ) -> tuple[np.ndarray, Union[int, float]]:
        """
        Asynchronously fetch audio from a URL.
        """
        audio_io = AudioMediaIO()

        return await self.load_from_url_async(
            audio_url,
            audio_io,
            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
        )

    def fetch_image(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Load a PIL image from a HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode)

        return self.load_from_url(
            image_url,
            image_io,
            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )

    async def fetch_image_async(
        self,
        image_url: str,
        *,
        image_mode: str = "RGB",
    ) -> Image.Image:
        """
        Asynchronously load a PIL image from a HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode)

        return await self.load_from_url_async(
            image_url,
            image_io,
            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
        )

    def fetch_video(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
        num_frames: int = 32,
    ) -> npt.NDArray:
        """
        Load video from a HTTP or base64 data URL.
        """
        image_io = ImageMediaIO(image_mode=image_mode)
        video_io = VideoMediaIO(image_io, num_frames=num_frames)

        return self.load_from_url(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

    async def fetch_video_async(
        self,
        video_url: str,
        *,
        image_mode: str = "RGB",
        num_frames: int = 32,
    ) -> npt.NDArray:
        """
        Asynchronously load video from a HTTP or base64 data URL.

        By default, the image is converted into RGB format.
        """
        image_io = ImageMediaIO(image_mode=image_mode)
        video_io = VideoMediaIO(image_io, num_frames=num_frames)

        return await self.load_from_url_async(
            video_url,
            video_io,
            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )

global_media_connector = MediaConnector()
"""The global :class:`MediaConnector` instance used by vLLM."""

fetch_audio = global_media_connector.fetch_audio
fetch_image = global_media_connector.fetch_image
fetch_video = global_media_connector.fetch_video
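The module-level fetch_audio, fetch_image and fetch_video names survive as bound methods of a default connector, so existing callers that never passed allowed_local_media_path keep working; only the get_and_parse_* wrappers disappear, their dict-building having moved into the chat trackers. For instance (URL illustrative):

from vllm.multimodal.utils import fetch_image, global_media_connector

# Both calls hit the same default MediaConnector instance.
image = fetch_image("https://example.com/cat.jpg")
same = global_media_connector.fetch_image("https://example.com/cat.jpg")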
def encode_audio_base64(
@ -294,10 +256,8 @@ def encode_audio_base64(
sampling_rate: int,
) -> str:
"""Encode audio as base64."""
buffered = BytesIO()
soundfile.write(buffered, audio, sampling_rate, format="WAV")
return base64.b64encode(buffered.getvalue()).decode('utf-8')
audio_io = AudioMediaIO()
return audio_io.encode_base64((audio, sampling_rate))
def encode_image_base64(
@ -311,29 +271,14 @@ def encode_image_base64(
By default, the image is converted into RGB format before being encoded.
"""
buffered = BytesIO()
image = image.convert(image_mode)
image.save(buffered, format)
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
"""Load image from base64 format."""
return _load_image_from_bytes(base64.b64decode(image))
image_io = ImageMediaIO(image_mode=image_mode)
return image_io.encode_base64(image, image_format=format)
def encode_video_base64(frames: npt.NDArray) -> str:
base64_frames = []
frames_list = [frames[i] for i in range(frames.shape[0])]
for frame in frames_list:
img_base64 = encode_image_base64(Image.fromarray(frame))
base64_frames.append(img_base64)
return ",".join(base64_frames)
def load_video_from_base64(video: Union[bytes, str]) -> npt.NDArray:
"""Load video from base64 format."""
return _load_video_from_bytes(base64.b64decode(video))
image_io = ImageMediaIO()
video_io = VideoMediaIO(image_io)
return video_io.encode_base64(frames)
def resolve_visual_encoder_outputs(
@ -389,7 +334,7 @@ def repeat_and_pad_token(
repeat_count: int = 1,
pad_token_left: Optional[_T] = None,
pad_token_right: Optional[_T] = None,
) -> List[_T]:
) -> list[_T]:
replacement = [token] * repeat_count
if pad_token_left is not None:
replacement = [pad_token_left] + replacement
@ -402,13 +347,13 @@ def repeat_and_pad_token(
def repeat_and_pad_placeholder_tokens(
tokenizer: AnyTokenizer,
prompt: Optional[str],
prompt_token_ids: List[int],
prompt_token_ids: list[int],
*,
placeholder_token_id: int,
repeat_count: Union[int, List[int]],
repeat_count: Union[int, list[int]],
pad_token_left: Optional[int] = None,
pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int], List[PlaceholderRange]]:
) -> tuple[Optional[str], list[int], list[PlaceholderRange]]:
if isinstance(repeat_count, int):
repeat_count = [repeat_count]
@ -450,8 +395,8 @@ def repeat_and_pad_placeholder_tokens(
new_prompt += prompt_parts[i] + replacement_str
new_prompt += prompt_parts[-1]
new_token_ids: List[int] = []
placeholder_ranges: List[PlaceholderRange] = []
new_token_ids = list[int]()
placeholder_ranges = list[PlaceholderRange]()
placeholder_token_idx = 0
for i, token in enumerate(prompt_token_ids):
if token == placeholder_token_id:
@ -481,7 +426,7 @@ def repeat_and_pad_placeholder_tokens(
def consecutive_placeholder_ranges(
num_items: int,
item_size: int,
initial_offset: int = 0) -> List[PlaceholderRange]:
initial_offset: int = 0) -> list[PlaceholderRange]:
"""Returns a list of consecutive PlaceholderRanges of a fixed size"""
return [

View File

@ -1,23 +1,32 @@
from functools import lru_cache
import base64
from functools import lru_cache, partial
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional
import cv2
import numpy as np
import numpy.typing as npt
from PIL import Image
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import is_list_of
from vllm.utils import PlaceholderModule, is_list_of
from .base import MultiModalData
from .image import ImagePlugin
from .base import MediaIO, MultiModalData
from .image import ImageMediaIO, ImagePlugin
from .inputs import MultiModalKwargs, VideoItem
if TYPE_CHECKING:
from vllm.config import ModelConfig
try:
import decord
except ImportError:
decord = PlaceholderModule("decord") # type: ignore[assignment]
logger = init_logger(__name__)
cached_get_video_processor = lru_cache(get_video_processor)
@ -107,3 +116,73 @@ def sample_frames_from_video(frames: npt.NDArray,
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
sampled_frames = frames[frame_indices, ...]
return sampled_frames
class VideoMediaIO(MediaIO[npt.NDArray]):
def __init__(
self,
image_io: ImageMediaIO,
*,
num_frames: int = 32,
) -> None:
super().__init__()
self.image_io = image_io
self.num_frames = num_frames
def load_bytes(self, data: bytes) -> npt.NDArray:
vr = decord.VideoReader(BytesIO(data), num_threads=1)
total_frame_num = len(vr)
num_frames = self.num_frames
if total_frame_num > num_frames:
uniform_sampled_frames = np.linspace(0,
total_frame_num - 1,
num_frames,
dtype=int)
frame_idx = uniform_sampled_frames.tolist()
else:
frame_idx = list(range(0, total_frame_num))
return vr.get_batch(frame_idx).asnumpy()
def load_base64(self, media_type: str, data: str) -> npt.NDArray:
if media_type.lower() == "video/jpeg":
load_frame = partial(
self.image_io.load_base64,
"image/jpeg",
)
return np.stack([
np.array(load_frame(frame_data))
for frame_data in data.split(",")
])
return self.load_bytes(base64.b64decode(data))
def load_file(self, filepath: Path) -> npt.NDArray:
with filepath.open("rb") as f:
data = f.read()
return self.load_bytes(data)
def encode_base64(
self,
media: npt.NDArray,
*,
video_format: str = "JPEG",
) -> str:
video = media
if video_format == "JPEG":
encode_frame = partial(
self.image_io.encode_base64,
image_format=video_format,
)
return ",".join(
encode_frame(Image.fromarray(frame)) for frame in video)
msg = "Only JPEG format is supported for now."
raise NotImplementedError(msg)
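VideoMediaIO composes an ImageMediaIO so that JPEG-frame data URLs and base64 encoding reuse the image codec, while raw video bytes go through decord. A usage sketch, assuming decord is installed and the file path is a placeholder:

from pathlib import Path

from vllm.multimodal.image import ImageMediaIO
from vllm.multimodal.video import VideoMediaIO

video_io = VideoMediaIO(ImageMediaIO(image_mode="RGB"), num_frames=32)
frames = video_io.load_file(Path("sample.mp4"))  # (num_frames, H, W, 3) uint8 array
b64 = video_io.encode_base64(frames)             # comma-separated base64 JPEG frames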