[Misc] Abstract the logic for reading and writing media content (#11527)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2024-12-27 19:21:23 +08:00 committed by GitHub
parent 2c9b8ea2b0
commit 7af553ea30
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 495 additions and 389 deletions

View File

@ -33,6 +33,7 @@ class MockModelConfig:
hf_config = MockHFConfig() hf_config = MockHFConfig()
logits_processor_pattern = None logits_processor_pattern = None
diff_sampling_param: Optional[dict] = None diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = ""
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}

View File

@ -2,7 +2,6 @@ import warnings
from typing import Optional from typing import Optional
import pytest import pytest
from PIL import Image
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig from vllm.config import ModelConfig
@ -91,10 +90,7 @@ def _assert_mm_data_is_image_input(
image_data = mm_data.get("image") image_data = mm_data.get("image")
assert image_data is not None assert image_data is not None
if image_count == 1: assert isinstance(image_data, list) and len(image_data) == image_count
assert isinstance(image_data, Image.Image)
else:
assert isinstance(image_data, list) and len(image_data) == image_count
def test_parse_chat_messages_single_image( def test_parse_chat_messages_single_image(

View File

@ -9,7 +9,7 @@ import pytest
from PIL import Image, ImageChops from PIL import Image, ImageChops
from transformers import AutoConfig, AutoTokenizer from transformers import AutoConfig, AutoTokenizer
from vllm.multimodal.utils import (async_fetch_image, fetch_image, from vllm.multimodal.utils import (MediaConnector,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
@ -23,7 +23,12 @@ TEST_IMAGE_URLS = [
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def url_images() -> Dict[str, Image.Image]: def url_images() -> Dict[str, Image.Image]:
return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS} connector = MediaConnector()
return {
image_url: connector.fetch_image(image_url)
for image_url in TEST_IMAGE_URLS
}
def get_supported_suffixes() -> Tuple[str, ...]: def get_supported_suffixes() -> Tuple[str, ...]:
@ -43,8 +48,10 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_http(image_url: str): async def test_fetch_image_http(image_url: str):
image_sync = fetch_image(image_url) connector = MediaConnector()
image_async = await async_fetch_image(image_url)
image_sync = connector.fetch_image(image_url)
image_async = await connector.fetch_image_async(image_url)
assert _image_equals(image_sync, image_async) assert _image_equals(image_sync, image_async)
@ -53,6 +60,7 @@ async def test_fetch_image_http(image_url: str):
@pytest.mark.parametrize("suffix", get_supported_suffixes()) @pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: Dict[str, Image.Image], async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
image_url: str, suffix: str): image_url: str, suffix: str):
connector = MediaConnector()
url_image = url_images[image_url] url_image = url_images[image_url]
try: try:
@ -75,48 +83,49 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
base64_image = base64.b64encode(f.read()).decode("utf-8") base64_image = base64.b64encode(f.read()).decode("utf-8")
data_url = f"data:{mime_type};base64,{base64_image}" data_url = f"data:{mime_type};base64,{base64_image}"
data_image_sync = fetch_image(data_url) data_image_sync = connector.fetch_image(data_url)
if _image_equals(url_image, Image.open(f)): if _image_equals(url_image, Image.open(f)):
assert _image_equals(url_image, data_image_sync) assert _image_equals(url_image, data_image_sync)
else: else:
pass # Lossy format; only check that image can be opened pass # Lossy format; only check that image can be opened
data_image_async = await async_fetch_image(data_url) data_image_async = await connector.fetch_image_async(data_url)
assert _image_equals(data_image_sync, data_image_async) assert _image_equals(data_image_sync, data_image_async)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_local_files(image_url: str): async def test_fetch_image_local_files(image_url: str):
connector = MediaConnector()
with TemporaryDirectory() as temp_dir: with TemporaryDirectory() as temp_dir:
origin_image = fetch_image(image_url) local_connector = MediaConnector(allowed_local_media_path=temp_dir)
origin_image = connector.fetch_image(image_url)
origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)), origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)),
quality=100, quality=100,
icc_profile=origin_image.info.get('icc_profile')) icc_profile=origin_image.info.get('icc_profile'))
image_async = await async_fetch_image( image_async = await local_connector.fetch_image_async(
f"file://{temp_dir}/{os.path.basename(image_url)}", f"file://{temp_dir}/{os.path.basename(image_url)}")
allowed_local_media_path=temp_dir) image_sync = local_connector.fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}")
image_sync = fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)
# Check that the images are equal # Check that the images are equal
assert not ImageChops.difference(image_sync, image_async).getbbox() assert not ImageChops.difference(image_sync, image_async).getbbox()
with pytest.raises(ValueError): with pytest.raises(ValueError, match="must be a subpath"):
await async_fetch_image( await local_connector.fetch_image_async(
f"file://{temp_dir}/../{os.path.basename(image_url)}", f"file://{temp_dir}/../{os.path.basename(image_url)}")
allowed_local_media_path=temp_dir) with pytest.raises(RuntimeError, match="Cannot load local files"):
with pytest.raises(ValueError): await connector.fetch_image_async(
await async_fetch_image(
f"file://{temp_dir}/../{os.path.basename(image_url)}") f"file://{temp_dir}/../{os.path.basename(image_url)}")
with pytest.raises(ValueError): with pytest.raises(ValueError, match="must be a subpath"):
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}", local_connector.fetch_image(
allowed_local_media_path=temp_dir) f"file://{temp_dir}/../{os.path.basename(image_url)}")
with pytest.raises(ValueError): with pytest.raises(RuntimeError, match="Cannot load local files"):
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}") connector.fetch_image(
f"file://{temp_dir}/../{os.path.basename(image_url)}")
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])

View File

@ -21,12 +21,10 @@ class AudioAsset:
name: Literal["winning_call", "mary_had_lamb"] name: Literal["winning_call", "mary_had_lamb"]
@property @property
def audio_and_sample_rate(self) -> tuple[npt.NDArray, int]: def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg", audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
s3_prefix=ASSET_DIR) s3_prefix=ASSET_DIR)
y, sr = librosa.load(audio_path, sr=None) return librosa.load(audio_path, sr=None)
assert isinstance(sr, int)
return y, sr
@property @property
def url(self) -> str: def url(self) -> str:

View File

@ -6,7 +6,7 @@ from collections import defaultdict, deque
from functools import lru_cache, partial from functools import lru_cache, partial
from pathlib import Path from pathlib import Path
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast) Literal, Optional, Tuple, TypeVar, Union, cast)
import jinja2.nodes import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils import transformers.utils.chat_template_utils as hf_chat_utils
@ -23,6 +23,8 @@ from openai.types.chat import (
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
from openai.types.chat import (ChatCompletionMessageToolCallParam, from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam) ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio)
# yapf: enable # yapf: enable
# pydantic needs the TypedDict from typing_extensions # pydantic needs the TypedDict from typing_extensions
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@ -31,11 +33,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import (async_get_and_parse_audio, from vllm.multimodal.utils import MediaConnector
async_get_and_parse_image,
async_get_and_parse_video,
get_and_parse_audio, get_and_parse_image,
get_and_parse_video)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
@ -368,14 +366,17 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._tokenizer = tokenizer self._tokenizer = tokenizer
self._allowed_items = (model_config.multimodal_config.limit_per_prompt self._allowed_items = (model_config.multimodal_config.limit_per_prompt
if model_config.multimodal_config else {}) if model_config.multimodal_config else {})
self._consumed_items = {k: 0 for k in self._allowed_items}
self._items: List[_T] = [] self._items_by_modality = defaultdict[str, list[_T]](list)
@property @property
def model_config(self) -> ModelConfig: def model_config(self) -> ModelConfig:
return self._model_config return self._model_config
@property
def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path
@staticmethod @staticmethod
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str: def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
@ -435,38 +436,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
else: else:
raise TypeError(f"Unknown modality: {modality}") raise TypeError(f"Unknown modality: {modality}")
@staticmethod
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
mm_lists: Mapping[str, List[object]] = defaultdict(list)
# Merge all the multi-modal items
for single_mm_data in items:
for mm_key, mm_item in single_mm_data.items():
if isinstance(mm_item, list):
mm_lists[mm_key].extend(mm_item)
else:
mm_lists[mm_key].append(mm_item)
# Unpack any single item lists for models that don't expect multiple.
return {
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
for mm_key, mm_list in mm_lists.items()
}
def add(self, modality: ModalityStr, item: _T) -> Optional[str]: def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
""" """
Add a multi-modal item to the current prompt and returns the Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any. placeholder string to use, if any.
""" """
allowed_count = self._allowed_items.get(modality, 1) allowed_count = self._allowed_items.get(modality, 1)
current_count = self._consumed_items.get(modality, 0) + 1 current_count = len(self._items_by_modality[modality]) + 1
if current_count > allowed_count: if current_count > allowed_count:
raise ValueError( raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in " f"At most {allowed_count} {modality}(s) may be provided in "
"one request.") "one request.")
self._consumed_items[modality] = current_count self._items_by_modality[modality].append(item)
self._items.append(item)
return self._placeholder_str(modality, current_count) return self._placeholder_str(modality, current_count)
@ -475,22 +457,26 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
raise NotImplementedError raise NotImplementedError
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]): class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def all_mm_data(self) -> Optional[MultiModalDataDict]: def all_mm_data(self) -> Optional[MultiModalDataDict]:
return self._combine(self._items) if self._items else None if self._items_by_modality:
return dict(self._items_by_modality)
return None
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self) return MultiModalContentParser(self)
class AsyncMultiModalItemTracker( class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
async def all_mm_data(self) -> Optional[MultiModalDataDict]: async def all_mm_data(self) -> Optional[MultiModalDataDict]:
if self._items: if self._items_by_modality:
items = await asyncio.gather(*self._items) return {
return self._combine(items) modality: await asyncio.gather(*items)
for modality, items in self._items_by_modality.items()
}
return None return None
@ -522,7 +508,7 @@ class BaseMultiModalContentParser(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def parse_input_audio(self, input_audio: Dict[str, str]) -> None: def parse_input_audio(self, input_audio: InputAudio) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
@ -537,31 +523,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._tracker = tracker self._tracker = tracker
self._connector = MediaConnector(
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str) -> None: def parse_image(self, image_url: str) -> None:
image = get_and_parse_image(image_url, image = self._connector.fetch_image(image_url)
allowed_local_media_path=self._tracker.
_model_config.allowed_local_media_path)
placeholder = self._tracker.add("image", image) placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None: def parse_audio(self, audio_url: str) -> None:
audio = get_and_parse_audio(audio_url) audio = self._connector.fetch_audio(audio_url)
placeholder = self._tracker.add("audio", audio) placeholder = self._tracker.add("audio", audio)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
def parse_input_audio(self, input_audio: Dict[str, str]) -> None: def parse_input_audio(self, input_audio: InputAudio) -> None:
input_audio_data = input_audio.get("data","") audio_data = input_audio.get("data", "")
input_audio_format = input_audio.get("format","") audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}" audio_url = f"data:audio/{audio_format};base64,{audio_data}"
audio = get_and_parse_audio(audio_url)
placeholder = self._tracker.add("audio", audio) return self.parse_audio(audio_url)
self._add_placeholder(placeholder)
def parse_video(self, video_url: str) -> None: def parse_video(self, video_url: str) -> None:
video = get_and_parse_video(video_url) video = self._connector.fetch_video(video_url)
placeholder = self._tracker.add("video", video) placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
@ -573,33 +559,31 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
super().__init__() super().__init__()
self._tracker = tracker self._tracker = tracker
self._connector = MediaConnector(
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str) -> None: def parse_image(self, image_url: str) -> None:
image_coro = async_get_and_parse_image( image_coro = self._connector.fetch_image_async(image_url)
image_url,
allowed_local_media_path=self._tracker._model_config.
allowed_local_media_path)
placeholder = self._tracker.add("image", image_coro) placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
def parse_audio(self, audio_url: str) -> None: def parse_audio(self, audio_url: str) -> None:
audio_coro = async_get_and_parse_audio(audio_url) audio_coro = self._connector.fetch_audio_async(audio_url)
placeholder = self._tracker.add("audio", audio_coro) placeholder = self._tracker.add("audio", audio_coro)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
def parse_input_audio(self, input_audio: Dict[str, str]) -> None: def parse_input_audio(self, input_audio: InputAudio) -> None:
input_audio_data = input_audio.get("data","") audio_data = input_audio.get("data", "")
input_audio_format = input_audio.get("format","") audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}" audio_url = f"data:audio/{audio_format};base64,{audio_data}"
audio_coro = async_get_and_parse_audio(audio_url)
placeholder = self._tracker.add("audio", audio_coro) return self.parse_audio(audio_url)
self._add_placeholder(placeholder)
def parse_video(self, video_url: str) -> None: def parse_video(self, video_url: str) -> None:
video = async_get_and_parse_video(video_url) video = self._connector.fetch_video_async(video_url)
placeholder = self._tracker.add("video", video) placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
@ -695,10 +679,13 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam) _VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
# Define a mapping from part types to their corresponding parsing functions. # Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[str, MM_PARSER_MAP: Dict[
Callable[[ChatCompletionContentPartParam], str,
Union[str, Dict[str,str]]]] = { Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
"text": "text":
lambda part: _TextParser(part).get("text", ""), lambda part: _TextParser(part).get("text", ""),
"image_url": "image_url":
@ -715,8 +702,7 @@ MM_PARSER_MAP: Dict[str,
def _parse_chat_message_content_mm_part( def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> Tuple[str, part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]:
Union[str, Dict[str, str]]]:
""" """
Parses a given multi-modal content part based on its type. Parses a given multi-modal content part based on its type.
@ -783,7 +769,7 @@ def _parse_chat_message_content_parts(
*, *,
wrap_dicts: bool, wrap_dicts: bool,
) -> List[ConversationMessage]: ) -> List[ConversationMessage]:
content: List[Union[str, Dict[str, str]]] = [] content = list[_ContentPart]()
mm_parser = mm_tracker.create_parser() mm_parser = mm_tracker.create_parser()
@ -814,7 +800,7 @@ def _parse_chat_message_content_part(
mm_parser: BaseMultiModalContentParser, mm_parser: BaseMultiModalContentParser,
*, *,
wrap_dicts: bool, wrap_dicts: bool,
) -> Optional[Union[str, Dict[str, str]]]: ) -> Optional[_ContentPart]:
"""Parses a single part of a conversation. If wrap_dicts is True, """Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
@ -823,8 +809,7 @@ def _parse_chat_message_content_part(
with multimodal placeholders. with multimodal placeholders.
""" """
if isinstance(part, str): # Handle plain text parts if isinstance(part, str): # Handle plain text parts
text = _TextParser(part) return part
return text
# Handle structured dictionary parts # Handle structured dictionary parts
part_type, content = _parse_chat_message_content_mm_part(part) part_type, content = _parse_chat_message_content_mm_part(part)
@ -855,7 +840,7 @@ def _parse_chat_message_content_part(
return {'type': 'audio'} if wrap_dicts else None return {'type': 'audio'} if wrap_dicts else None
if part_type == "input_audio": if part_type == "input_audio":
dict_content = cast(Dict[str, str], content) dict_content = cast(InputAudio, content)
mm_parser.parse_input_audio(dict_content) mm_parser.parse_input_audio(dict_content)
return {'type': 'audio'} if wrap_dicts else None return {'type': 'audio'} if wrap_dicts else None

View File

@ -1,10 +1,14 @@
import base64
from io import BytesIO
from pathlib import Path
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.utils import PlaceholderModule from vllm.utils import PlaceholderModule
from .base import MultiModalPlugin from .base import MediaIO, MultiModalPlugin
from .inputs import AudioItem, MultiModalData, MultiModalKwargs from .inputs import AudioItem, MultiModalData, MultiModalKwargs
try: try:
@ -12,6 +16,11 @@ try:
except ImportError: except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment] librosa = PlaceholderModule("librosa") # type: ignore[assignment]
try:
import soundfile
except ImportError:
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
class AudioPlugin(MultiModalPlugin): class AudioPlugin(MultiModalPlugin):
"""Plugin for audio data.""" """Plugin for audio data."""
@ -39,3 +48,28 @@ def resample_audio(
target_sr: float, target_sr: float,
) -> npt.NDArray[np.floating]: ) -> npt.NDArray[np.floating]:
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr) return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
    """Read and write audio as a ``(samples, sampling_rate)`` pair.

    Decoding is delegated to ``librosa.load`` and WAV encoding to
    ``soundfile.write``.
    """

    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
        """Decode raw audio bytes by wrapping them in an in-memory stream."""
        stream = BytesIO(data)
        # sr=None is forwarded to librosa so that no target sampling rate
        # is imposed here.
        return librosa.load(stream, sr=None)

    def load_base64(
        self,
        media_type: str,
        data: str,
    ) -> tuple[npt.NDArray, float]:
        """Decode a base64 payload and load it like raw bytes.

        ``media_type`` is accepted to satisfy the ``MediaIO`` interface but
        is not consulted by this implementation.
        """
        raw = base64.b64decode(data)
        return self.load_bytes(raw)

    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
        """Load audio from a path on the local filesystem."""
        return librosa.load(filepath, sr=None)

    def encode_base64(self, media: tuple[npt.NDArray, float]) -> str:
        """Serialize ``media`` as WAV and return it as a base64 string."""
        samples, sampling_rate = media

        with BytesIO() as wav_stream:
            soundfile.write(wav_stream, samples, sampling_rate, format="WAV")
            # Grab the bytes before the stream is closed by the `with` block.
            encoded = base64.b64encode(wav_stream.getvalue())

        return encoded.decode('utf-8')

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple,
Optional, Sequence, Tuple, Type, TypeVar, Union) Optional, Sequence, Tuple, Type, TypeVar, Union)
from torch import nn from torch import nn
@ -118,7 +119,7 @@ class MultiModalPlugin(ABC):
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
data: MultiModalData[Any], data: MultiModalData[Any],
mm_processor_kwargs: Optional[Dict[str, Any]], mm_processor_kwargs: Optional[dict[str, Any]],
) -> MultiModalKwargs: ) -> MultiModalKwargs:
""" """
Transform the data into a dictionary of model inputs using the Transform the data into a dictionary of model inputs using the
@ -254,10 +255,10 @@ class MultiModalPlaceholderMap:
""" """
class IndexMap(NamedTuple): class IndexMap(NamedTuple):
src: List[int] src: list[int]
dest: List[int] dest: list[int]
src_ranges: List[range] src_ranges: list[range]
""" """
The indices of the multi-modal embeddings that will replace the The indices of the multi-modal embeddings that will replace the
corresponding placeholder embeddings pointed to by ``dest_ranges``. corresponding placeholder embeddings pointed to by ``dest_ranges``.
@ -268,7 +269,7 @@ class MultiModalPlaceholderMap:
The total number of flattened multi-modal embeddings. The total number of flattened multi-modal embeddings.
""" """
dest_ranges: List[range] dest_ranges: list[range]
""" """
The indices of the placeholder embeddings that will be replaced by the The indices of the placeholder embeddings that will be replaced by the
multimodal embeddings. multimodal embeddings.
@ -288,7 +289,7 @@ class MultiModalPlaceholderMap:
@classmethod @classmethod
def from_seq_group( def from_seq_group(
cls, seq_group: "SequenceGroupMetadata", positions: range cls, seq_group: "SequenceGroupMetadata", positions: range
) -> Tuple[Optional[MultiModalDataDict], Dict[str, ) -> Tuple[Optional[MultiModalDataDict], dict[str,
"MultiModalPlaceholderMap"]]: "MultiModalPlaceholderMap"]]:
""" """
Returns the multi-modal items that intersect with the portion of a Returns the multi-modal items that intersect with the portion of a
@ -376,9 +377,9 @@ class MultiModalPlaceholderMap:
def append_items_from_seq_group( def append_items_from_seq_group(
self, self,
positions: range, positions: range,
multi_modal_items: List[_T], multi_modal_items: list[_T],
multi_modal_placeholders: Sequence[PlaceholderRange], multi_modal_placeholders: Sequence[PlaceholderRange],
) -> List[_T]: ) -> list[_T]:
""" """
Adds the multi-modal items that intersect ```positions`` to this Adds the multi-modal items that intersect ```positions`` to this
placeholder map and returns the intersecting items. placeholder map and returns the intersecting items.
@ -454,3 +455,22 @@ class MultiModalPlaceholderMap:
return MultiModalPlaceholderMap.IndexMap(src=src_indices, return MultiModalPlaceholderMap.IndexMap(src=src_indices,
dest=dest_indices) dest=dest_indices)
class MediaIO(ABC, Generic[_T]):
    """Abstract interface for turning encoded media into an object of type
    ``_T``.

    Concrete subclasses implement one loader per input source: raw bytes,
    a base64 string, or a file on disk.
    """

    @abstractmethod
    def load_bytes(self, data: bytes) -> _T:
        """Build the media object from its raw encoded bytes."""
        raise NotImplementedError

    @abstractmethod
    def load_base64(self, media_type: str, data: str) -> _T:
        """Build the media object from a base64-encoded string.

        ``media_type`` names the kind of content carried by ``data``.

        List of media types:
        https://www.iana.org/assignments/media-types/media-types.xhtml
        """
        raise NotImplementedError

    @abstractmethod
    def load_file(self, filepath: Path) -> _T:
        """Build the media object from the file located at ``filepath``."""
        raise NotImplementedError

View File

@ -1,4 +1,7 @@
import base64
from functools import lru_cache from functools import lru_cache
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
import torch import torch
@ -9,7 +12,7 @@ from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_image_processor from vllm.transformers_utils.processor import get_image_processor
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .base import MultiModalPlugin from .base import MediaIO, MultiModalPlugin
from .inputs import ImageItem, MultiModalData, MultiModalKwargs from .inputs import ImageItem, MultiModalData, MultiModalKwargs
if TYPE_CHECKING: if TYPE_CHECKING:
@ -96,3 +99,39 @@ def rescale_image_size(image: Image.Image,
if transpose >= 0: if transpose >= 0:
image = image.transpose(Image.Transpose(transpose)) image = image.transpose(Image.Transpose(transpose))
return image return image
class ImageMediaIO(MediaIO[Image.Image]):
    """Read and write PIL images, normalising them to one colour mode.

    Every loaded or encoded image is converted to ``image_mode``
    (``"RGB"`` unless overridden).
    """

    def __init__(self, *, image_mode: str = "RGB") -> None:
        super().__init__()

        # Mode handed to ``Image.convert`` for every image passing through.
        self.image_mode = image_mode

    def load_bytes(self, data: bytes) -> Image.Image:
        """Decode an image from its raw encoded bytes."""
        decoded = Image.open(BytesIO(data))
        # Image.open is lazy; load() reads the pixel data right away.
        decoded.load()
        return decoded.convert(self.image_mode)

    def load_base64(self, media_type: str, data: str) -> Image.Image:
        """Decode a base64 payload and load it like raw bytes.

        ``media_type`` is accepted to satisfy the ``MediaIO`` interface but
        is not consulted by this implementation.
        """
        raw = base64.b64decode(data)
        return self.load_bytes(raw)

    def load_file(self, filepath: Path) -> Image.Image:
        """Open an image from a path on the local filesystem."""
        opened = Image.open(filepath)
        # Image.open is lazy; load() reads the pixel data right away.
        opened.load()
        return opened.convert(self.image_mode)

    def encode_base64(
        self,
        media: Image.Image,
        *,
        image_format: str = "JPEG",
    ) -> str:
        """Return ``media`` encoded as ``image_format`` in base64 text.

        The image is converted to ``self.image_mode`` before it is saved.
        """
        converted = media.convert(self.image_mode)

        with BytesIO() as stream:
            converted.save(stream, image_format)
            # Grab the bytes before the stream is closed by the `with` block.
            payload = stream.getvalue()

        return base64.b64encode(payload).decode('utf-8')

View File

@ -1,8 +1,7 @@
import base64
import os
from functools import lru_cache from functools import lru_cache
from io import BytesIO from pathlib import Path
from typing import List, Optional, Tuple, TypeVar, Union from typing import Optional, TypeVar, Union
from urllib.parse import ParseResult, urlparse
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@ -10,283 +9,246 @@ import torch
from PIL import Image from PIL import Image
import vllm.envs as envs import vllm.envs as envs
from vllm.connections import global_http_connection from vllm.connections import HTTPConnection, global_http_connection
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.utils import PlaceholderModule
from .inputs import MultiModalDataDict, PlaceholderRange from .audio import AudioMediaIO
from .base import MediaIO
try: from .image import ImageMediaIO
import decord from .inputs import PlaceholderRange
except ImportError: from .video import VideoMediaIO
decord = PlaceholderModule("decord") # type: ignore[assignment]
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
try:
import soundfile
except ImportError:
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
logger = init_logger(__name__) logger = init_logger(__name__)
cached_get_tokenizer = lru_cache(get_tokenizer) cached_get_tokenizer = lru_cache(get_tokenizer)
_M = TypeVar("_M")
def _load_image_from_bytes(b: bytes) -> Image.Image:
image = Image.open(BytesIO(b))
image.load()
return image
def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool: class MediaConnector:
# Get the common path
common_path = os.path.commonpath([
os.path.abspath(image_path),
os.path.abspath(allowed_local_media_path)
])
# Check if the common path is the same as allowed_local_media_path
return common_path == os.path.abspath(allowed_local_media_path)
def __init__(
self,
connection: HTTPConnection = global_http_connection,
*,
allowed_local_media_path: str = "",
) -> None:
super().__init__()
def _load_image_from_file(image_url: str, self.connection = connection
allowed_local_media_path: str) -> Image.Image:
if not allowed_local_media_path: if allowed_local_media_path:
raise ValueError("Invalid 'image_url': Cannot load local files without" allowed_local_media_path_ = Path(allowed_local_media_path)
"'--allowed-local-media-path'.")
if allowed_local_media_path: if not allowed_local_media_path_.exists():
if not os.path.exists(allowed_local_media_path): raise ValueError(
"Invalid `--allowed-local-media-path`: The path "
f"{allowed_local_media_path_} does not exist.")
if not allowed_local_media_path_.is_dir():
raise ValueError(
"Invalid `--allowed-local-media-path`: The path "
f"{allowed_local_media_path_} must be a directory.")
else:
allowed_local_media_path_ = None
self.allowed_local_media_path = allowed_local_media_path_
def _load_data_url(
self,
url_spec: ParseResult,
media_io: MediaIO[_M],
) -> _M:
data_spec, data = url_spec.path.split(",", 1)
media_type, data_type = data_spec.split(";", 1)
if data_type != "base64":
msg = "Only base64 data URLs are supported for now."
raise NotImplementedError(msg)
return media_io.load_base64(media_type, data)
def _load_file_url(
self,
url_spec: ParseResult,
media_io: MediaIO[_M],
) -> _M:
allowed_local_media_path = self.allowed_local_media_path
if allowed_local_media_path is None:
raise RuntimeError("Cannot load local files without "
"`--allowed-local-media-path`.")
filepath = Path(url_spec.path)
if allowed_local_media_path not in filepath.resolve().parents:
raise ValueError( raise ValueError(
"Invalid '--allowed-local-media-path': " f"The file path {filepath} must be a subpath "
f"The path {allowed_local_media_path} does not exist.") f"of `--allowed-local-media-path` {allowed_local_media_path}.")
if not os.path.isdir(allowed_local_media_path):
raise ValueError(
"Invalid '--allowed-local-media-path': "
f"The path {allowed_local_media_path} must be a directory.")
# Only split once and assume the second part is the image path return media_io.load_file(filepath)
_, image_path = image_url.split("file://", 1)
if not _is_subpath(image_path, allowed_local_media_path):
raise ValueError(
f"Invalid 'image_url': The file path {image_path} must"
" be a subpath of '--allowed-local-media-path'"
f" '{allowed_local_media_path}'.")
image = Image.open(image_path) def load_from_url(
image.load() self,
return image url: str,
media_io: MediaIO[_M],
*,
fetch_timeout: Optional[int] = None,
) -> _M:
url_spec = urlparse(url)
if url_spec.scheme.startswith("http"):
connection = self.connection
data = connection.get_bytes(url, timeout=fetch_timeout)
def _load_image_from_data_url(image_url: str) -> Image.Image: return media_io.load_bytes(data)
# Only split once and assume the second part is the base64 encoded image
_, image_base64 = image_url.split(",", 1)
return load_image_from_base64(image_base64)
if url_spec.scheme == "data":
return self._load_data_url(url_spec, media_io)
def fetch_image(image_url: str, if url_spec.scheme == "file":
*, return self._load_file_url(url_spec, media_io)
image_mode: str = "RGB",
allowed_local_media_path: str = "") -> Image.Image:
"""
Load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format. msg = "The URL must be either a HTTP, data or file URL."
""" raise ValueError(msg)
if image_url.startswith('http'):
image_raw = global_http_connection.get_bytes(
image_url,
timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
)
image = _load_image_from_bytes(image_raw)
elif image_url.startswith('data:image'): async def load_from_url_async(
image = _load_image_from_data_url(image_url) self,
elif image_url.startswith('file://'): url: str,
image = _load_image_from_file(image_url, allowed_local_media_path) media_io: MediaIO[_M],
else: *,
raise ValueError("Invalid 'image_url': A valid 'image_url' must start " fetch_timeout: Optional[int] = None,
"with either 'data:image', 'file://' or 'http'.") ) -> _M:
url_spec = urlparse(url)
return image.convert(image_mode) if url_spec.scheme.startswith("http"):
connection = self.connection
data = await connection.async_get_bytes(url, timeout=fetch_timeout)
return media_io.load_bytes(data)
async def async_fetch_image(image_url: str, if url_spec.scheme == "data":
*, return self._load_data_url(url_spec, media_io)
image_mode: str = "RGB",
allowed_local_media_path: str = "") -> Image.Image:
"""
Asynchronously load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format. if url_spec.scheme == "file":
""" return self._load_file_url(url_spec, media_io)
if image_url.startswith('http'):
image_raw = await global_http_connection.async_get_bytes(
image_url,
timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
)
image = _load_image_from_bytes(image_raw)
elif image_url.startswith('data:image'): msg = "The URL must be either a HTTP, data or file URL."
image = _load_image_from_data_url(image_url) raise ValueError(msg)
elif image_url.startswith('file://'):
image = _load_image_from_file(image_url, allowed_local_media_path)
else:
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
"with either 'data:image', 'file://' or 'http'.")
return image.convert(image_mode) def fetch_audio(
self,
audio_url: str,
) -> tuple[np.ndarray, Union[int, float]]:
"""
Load audio from a URL.
"""
audio_io = AudioMediaIO()
return self.load_from_url(
def _load_video_from_bytes(b: bytes, num_frames: int = 32) -> npt.NDArray:
    """Decode a video from raw bytes, keeping at most ``num_frames`` frames.

    When the video has more than ``num_frames`` frames, indices are taken
    at even spacing from the first frame to the last one; otherwise every
    frame is kept.
    """
    reader = decord.VideoReader(BytesIO(b), num_threads=1)
    frame_count = len(reader)

    if frame_count <= num_frames:
        selected = list(range(frame_count))
    else:
        selected = np.linspace(0, frame_count - 1, num_frames,
                               dtype=int).tolist()

    return reader.get_batch(selected).asnumpy()
def _load_video_from_data_url(video_url: str) -> npt.NDArray:
    """Load a video from a ``data:video`` URL.

    ``data:video/jpeg;`` URLs are read as comma-separated base64 images that
    are stacked into one array; any other URL is decoded as a single
    base64-encoded video.
    """
    # Split only once: everything after the first comma is the payload.
    _header, payload = video_url.split(",", 1)

    if not video_url.startswith("data:video/jpeg;"):
        return load_video_from_base64(payload)

    frames = [
        np.array(load_image_from_base64(chunk))
        for chunk in payload.split(",")
    ]
    return np.stack(frames)
def fetch_video(video_url: str, *, num_frames: int = 32) -> npt.NDArray:
    """
    Load video from a HTTP or base64 data URL.

    Raises ``ValueError`` when the URL starts with neither ``http`` nor
    ``data:video``.
    """
    # Every "https" URL also starts with "http", so one prefix test covers
    # both.
    if video_url.startswith("http"):
        raw = global_http_connection.get_bytes(
            video_url,
            timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )
        return _load_video_from_bytes(raw, num_frames)

    if video_url.startswith("data:video"):
        return _load_video_from_data_url(video_url)

    raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
                     "with either 'data:video' or 'http'.")
async def async_fetch_video(video_url: str,
                            *,
                            num_frames: int = 32) -> npt.NDArray:
    """
    Asynchronously load video from a HTTP or base64 data URL.

    Raises ``ValueError`` when the URL starts with neither ``http`` nor
    ``data:video``.
    """
    # Every "https" URL also starts with "http", so one prefix test covers
    # both.
    if video_url.startswith("http"):
        raw = await global_http_connection.async_get_bytes(
            video_url,
            timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
        )
        return _load_video_from_bytes(raw, num_frames)

    if video_url.startswith("data:video"):
        return _load_video_from_data_url(video_url)

    raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
                     "with either 'data:video' or 'http'.")
def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
"""
Load audio from a URL.
"""
if audio_url.startswith("http"):
audio_bytes = global_http_connection.get_bytes(
audio_url, audio_url,
timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT, audio_io,
fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
) )
elif audio_url.startswith("data:audio"):
_, audio_base64 = audio_url.split(",", 1)
audio_bytes = base64.b64decode(audio_base64)
else:
raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
"with either 'data:audio' or 'http'.")
return librosa.load(BytesIO(audio_bytes), sr=None) async def fetch_audio_async(
self,
audio_url: str,
) -> tuple[np.ndarray, Union[int, float]]:
"""
Asynchronously fetch audio from a URL.
"""
audio_io = AudioMediaIO()
return await self.load_from_url_async(
async def async_fetch_audio(
audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
"""
Asynchronously fetch audio from a URL.
"""
if audio_url.startswith("http"):
audio_bytes = await global_http_connection.async_get_bytes(
audio_url, audio_url,
timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT, audio_io,
fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
) )
elif audio_url.startswith("data:audio"):
_, audio_base64 = audio_url.split(",", 1)
audio_bytes = base64.b64decode(audio_base64)
else:
raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
"with either 'data:audio' or 'http'.")
return librosa.load(BytesIO(audio_bytes), sr=None) def fetch_image(
self,
def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
    """Fetch the audio at ``audio_url`` and wrap it under the "audio" key."""
    samples, rate = fetch_audio(audio_url)
    parsed = {"audio": (samples, rate)}
    return parsed
def get_and_parse_image(
image_url: str, image_url: str,
*, *,
allowed_local_media_path: str = "") -> MultiModalDataDict: image_mode: str = "RGB",
image = fetch_image(image_url, ) -> Image.Image:
allowed_local_media_path=allowed_local_media_path) """
return {"image": image} Load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
image_io = ImageMediaIO(image_mode=image_mode)
def get_and_parse_video(video_url: str) -> MultiModalDataDict: return self.load_from_url(
video = fetch_video(video_url) image_url,
return {"video": video} image_io,
fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
)
async def fetch_image_async(
async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict: self,
audio, sr = await async_fetch_audio(audio_url)
return {"audio": (audio, sr)}
async def async_get_and_parse_image(
image_url: str, image_url: str,
*, *,
allowed_local_media_path: str = "") -> MultiModalDataDict: image_mode: str = "RGB",
image = await async_fetch_image( ) -> Image.Image:
image_url, allowed_local_media_path=allowed_local_media_path) """
return {"image": image} Asynchronously load a PIL image from a HTTP or base64 data URL.
By default, the image is converted into RGB format.
"""
image_io = ImageMediaIO(image_mode=image_mode)
return await self.load_from_url_async(
image_url,
image_io,
fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
)
def fetch_video(
    self,
    video_url: str,
    *,
    image_mode: str = "RGB",
    num_frames: int = 32,
) -> npt.NDArray:
    """
    Load video from a URL via ``load_from_url``.

    ``image_mode`` configures the image IO wrapped by the video IO, and
    ``num_frames`` is forwarded to the video IO.
    """
    frame_io = ImageMediaIO(image_mode=image_mode)
    loader = VideoMediaIO(frame_io, num_frames=num_frames)
    timeout = envs.VLLM_VIDEO_FETCH_TIMEOUT

    return self.load_from_url(video_url, loader, fetch_timeout=timeout)
async def fetch_video_async(
    self,
    video_url: str,
    *,
    image_mode: str = "RGB",
    num_frames: int = 32,
) -> npt.NDArray:
    """
    Asynchronously load video from a URL via ``load_from_url_async``.

    ``image_mode`` (RGB by default) configures the image IO wrapped by the
    video IO, and ``num_frames`` is forwarded to the video IO.
    """
    frame_io = ImageMediaIO(image_mode=image_mode)
    loader = VideoMediaIO(frame_io, num_frames=num_frames)
    timeout = envs.VLLM_VIDEO_FETCH_TIMEOUT

    return await self.load_from_url_async(video_url,
                                          loader,
                                          fetch_timeout=timeout)
async def async_get_and_parse_video(video_url: str) -> MultiModalDataDict: global_media_connector = MediaConnector()
video = await async_fetch_video(video_url) """The global :class:`MediaConnector` instance used by vLLM."""
return {"video": video}
fetch_audio = global_media_connector.fetch_audio
fetch_image = global_media_connector.fetch_image
fetch_video = global_media_connector.fetch_video
def encode_audio_base64( def encode_audio_base64(
@ -294,10 +256,8 @@ def encode_audio_base64(
sampling_rate: int, sampling_rate: int,
) -> str: ) -> str:
"""Encode audio as base64.""" """Encode audio as base64."""
buffered = BytesIO() audio_io = AudioMediaIO()
soundfile.write(buffered, audio, sampling_rate, format="WAV") return audio_io.encode_base64((audio, sampling_rate))
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def encode_image_base64( def encode_image_base64(
@ -311,29 +271,14 @@ def encode_image_base64(
By default, the image is converted into RGB format before being encoded. By default, the image is converted into RGB format before being encoded.
""" """
buffered = BytesIO() image_io = ImageMediaIO(image_mode=image_mode)
image = image.convert(image_mode) return image_io.encode_base64(image, image_format=format)
image.save(buffered, format)
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
    """Decode a base64-encoded image and load it as a PIL image."""
    raw = base64.b64decode(image)
    return _load_image_from_bytes(raw)
def encode_video_base64(frames: npt.NDArray) -> str: def encode_video_base64(frames: npt.NDArray) -> str:
base64_frames = [] image_io = ImageMediaIO()
frames_list = [frames[i] for i in range(frames.shape[0])] video_io = VideoMediaIO(image_io)
for frame in frames_list: return video_io.encode_base64(frames)
img_base64 = encode_image_base64(Image.fromarray(frame))
base64_frames.append(img_base64)
return ",".join(base64_frames)
def load_video_from_base64(video: Union[bytes, str]) -> npt.NDArray:
    """Decode a base64-encoded video and load its frames as an array."""
    raw = base64.b64decode(video)
    return _load_video_from_bytes(raw)
def resolve_visual_encoder_outputs( def resolve_visual_encoder_outputs(
@ -389,7 +334,7 @@ def repeat_and_pad_token(
repeat_count: int = 1, repeat_count: int = 1,
pad_token_left: Optional[_T] = None, pad_token_left: Optional[_T] = None,
pad_token_right: Optional[_T] = None, pad_token_right: Optional[_T] = None,
) -> List[_T]: ) -> list[_T]:
replacement = [token] * repeat_count replacement = [token] * repeat_count
if pad_token_left is not None: if pad_token_left is not None:
replacement = [pad_token_left] + replacement replacement = [pad_token_left] + replacement
@ -402,13 +347,13 @@ def repeat_and_pad_token(
def repeat_and_pad_placeholder_tokens( def repeat_and_pad_placeholder_tokens(
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
prompt: Optional[str], prompt: Optional[str],
prompt_token_ids: List[int], prompt_token_ids: list[int],
*, *,
placeholder_token_id: int, placeholder_token_id: int,
repeat_count: Union[int, List[int]], repeat_count: Union[int, list[int]],
pad_token_left: Optional[int] = None, pad_token_left: Optional[int] = None,
pad_token_right: Optional[int] = None, pad_token_right: Optional[int] = None,
) -> Tuple[Optional[str], List[int], List[PlaceholderRange]]: ) -> tuple[Optional[str], list[int], list[PlaceholderRange]]:
if isinstance(repeat_count, int): if isinstance(repeat_count, int):
repeat_count = [repeat_count] repeat_count = [repeat_count]
@ -450,8 +395,8 @@ def repeat_and_pad_placeholder_tokens(
new_prompt += prompt_parts[i] + replacement_str new_prompt += prompt_parts[i] + replacement_str
new_prompt += prompt_parts[-1] new_prompt += prompt_parts[-1]
new_token_ids: List[int] = [] new_token_ids = list[int]()
placeholder_ranges: List[PlaceholderRange] = [] placeholder_ranges = list[PlaceholderRange]()
placeholder_token_idx = 0 placeholder_token_idx = 0
for i, token in enumerate(prompt_token_ids): for i, token in enumerate(prompt_token_ids):
if token == placeholder_token_id: if token == placeholder_token_id:
@ -481,7 +426,7 @@ def repeat_and_pad_placeholder_tokens(
def consecutive_placeholder_ranges( def consecutive_placeholder_ranges(
num_items: int, num_items: int,
item_size: int, item_size: int,
initial_offset: int = 0) -> List[PlaceholderRange]: initial_offset: int = 0) -> list[PlaceholderRange]:
"""Returns a list of consecutive PlaceholderRanges of a fixed size""" """Returns a list of consecutive PlaceholderRanges of a fixed size"""
return [ return [

View File

@ -1,23 +1,32 @@
from functools import lru_cache import base64
from functools import lru_cache, partial
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
import cv2 import cv2
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
from PIL import Image
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_video_processor from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import is_list_of from vllm.utils import PlaceholderModule, is_list_of
from .base import MultiModalData from .base import MediaIO, MultiModalData
from .image import ImagePlugin from .image import ImageMediaIO, ImagePlugin
from .inputs import MultiModalKwargs, VideoItem from .inputs import MultiModalKwargs, VideoItem
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
try:
import decord
except ImportError:
decord = PlaceholderModule("decord") # type: ignore[assignment]
logger = init_logger(__name__) logger = init_logger(__name__)
cached_get_video_processor = lru_cache(get_video_processor) cached_get_video_processor = lru_cache(get_video_processor)
@ -107,3 +116,73 @@ def sample_frames_from_video(frames: npt.NDArray,
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
sampled_frames = frames[frame_indices, ...] sampled_frames = frames[frame_indices, ...]
return sampled_frames return sampled_frames
class VideoMediaIO(MediaIO[npt.NDArray]):
    """Read and write videos represented as arrays of frames.

    Encoded video bytes are decoded with ``decord``. The ``video/jpeg``
    base64 form is a comma-separated list of frames, each of which is
    delegated to the wrapped ``image_io``.
    """

    def __init__(
        self,
        image_io: ImageMediaIO,
        *,
        num_frames: int = 32,
    ) -> None:
        super().__init__()

        # Used for per-frame base64 loading and encoding.
        self.image_io = image_io
        # Upper bound on the number of frames kept by ``load_bytes``.
        self.num_frames = num_frames

    def load_bytes(self, data: bytes) -> npt.NDArray:
        """Decode ``data`` and return at most ``self.num_frames`` frames."""
        reader = decord.VideoReader(BytesIO(data), num_threads=1)
        frame_count = len(reader)
        limit = self.num_frames

        if frame_count <= limit:
            picked = list(range(frame_count))
        else:
            # Evenly spaced indices from the first frame to the last one.
            picked = np.linspace(0, frame_count - 1, limit,
                                 dtype=int).tolist()

        return reader.get_batch(picked).asnumpy()

    def load_base64(self, media_type: str, data: str) -> npt.NDArray:
        """Load a video from base64 text.

        ``video/jpeg`` (case-insensitive) is read as comma-separated base64
        images; any other media type is base64-decoded and handed to
        ``load_bytes``.
        """
        if media_type.lower() != "video/jpeg":
            return self.load_bytes(base64.b64decode(data))

        load_frame = self.image_io.load_base64
        frames = [
            np.array(load_frame("image/jpeg", chunk))
            for chunk in data.split(",")
        ]
        return np.stack(frames)

    def load_file(self, filepath: Path) -> npt.NDArray:
        """Read ``filepath`` as bytes and decode it with ``load_bytes``."""
        return self.load_bytes(filepath.read_bytes())

    def encode_base64(
        self,
        media: npt.NDArray,
        *,
        video_format: str = "JPEG",
    ) -> str:
        """Encode each frame of ``media`` and join the results with commas.

        Raises ``NotImplementedError`` for any ``video_format`` other than
        ``"JPEG"``.
        """
        if video_format != "JPEG":
            msg = "Only JPEG format is supported for now."
            raise NotImplementedError(msg)

        encode_frame = self.image_io.encode_base64
        return ",".join(
            encode_frame(Image.fromarray(frame), image_format=video_format)
            for frame in media)