[Misc] Abstract the logic for reading and writing media content (#11527)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 2c9b8ea2b0
commit 7af553ea30
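This commit replaces the free functions in `vllm.multimodal.utils` (`fetch_image`, `async_fetch_image`, `fetch_video`, `fetch_audio`, and the `get_and_parse_*` wrappers) with a `MediaConnector` class, and moves the per-format decode/encode logic into `MediaIO` implementations (`AudioMediaIO`, `ImageMediaIO`, `VideoMediaIO`). A minimal sketch of the resulting call pattern (a sketch against this commit; the URL is illustrative):

```python
# Sketch of the API introduced by this commit; the URL is illustrative.
import asyncio

from vllm.multimodal.utils import MediaConnector

connector = MediaConnector()  # no local file access by default

# Synchronous and asynchronous fetches now live on the same object.
image = connector.fetch_image("https://example.com/cat.jpg")
image_async = asyncio.run(
    connector.fetch_image_async("https://example.com/cat.jpg"))
```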
@@ -33,6 +33,7 @@ class MockModelConfig:
     hf_config = MockHFConfig()
     logits_processor_pattern = None
     diff_sampling_param: Optional[dict] = None
+    allowed_local_media_path: str = ""
 
     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}
@@ -2,7 +2,6 @@ import warnings
 from typing import Optional
 
 import pytest
-from PIL import Image
 
 from vllm.assets.image import ImageAsset
 from vllm.config import ModelConfig
@@ -91,10 +90,7 @@ def _assert_mm_data_is_image_input(
     image_data = mm_data.get("image")
     assert image_data is not None
 
-    if image_count == 1:
-        assert isinstance(image_data, Image.Image)
-    else:
-        assert isinstance(image_data, list) and len(image_data) == image_count
+    assert isinstance(image_data, list) and len(image_data) == image_count
 
 
 def test_parse_chat_messages_single_image(
@@ -9,7 +9,7 @@ import pytest
 from PIL import Image, ImageChops
 from transformers import AutoConfig, AutoTokenizer
 
-from vllm.multimodal.utils import (async_fetch_image, fetch_image,
+from vllm.multimodal.utils import (MediaConnector,
                                    repeat_and_pad_placeholder_tokens)
 
 # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
@@ -23,7 +23,12 @@ TEST_IMAGE_URLS = [
 
 @pytest.fixture(scope="module")
 def url_images() -> Dict[str, Image.Image]:
-    return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS}
+    connector = MediaConnector()
+
+    return {
+        image_url: connector.fetch_image(image_url)
+        for image_url in TEST_IMAGE_URLS
+    }
 
 
 def get_supported_suffixes() -> Tuple[str, ...]:
@@ -43,8 +48,10 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
 @pytest.mark.asyncio
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
 async def test_fetch_image_http(image_url: str):
-    image_sync = fetch_image(image_url)
-    image_async = await async_fetch_image(image_url)
+    connector = MediaConnector()
+
+    image_sync = connector.fetch_image(image_url)
+    image_async = await connector.fetch_image_async(image_url)
     assert _image_equals(image_sync, image_async)
 
 
@@ -53,6 +60,7 @@ async def test_fetch_image_http(image_url: str):
 @pytest.mark.parametrize("suffix", get_supported_suffixes())
 async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
                                   image_url: str, suffix: str):
+    connector = MediaConnector()
     url_image = url_images[image_url]
 
     try:
@@ -75,48 +83,49 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
         base64_image = base64.b64encode(f.read()).decode("utf-8")
         data_url = f"data:{mime_type};base64,{base64_image}"
 
-        data_image_sync = fetch_image(data_url)
+        data_image_sync = connector.fetch_image(data_url)
         if _image_equals(url_image, Image.open(f)):
             assert _image_equals(url_image, data_image_sync)
         else:
            pass  # Lossy format; only check that image can be opened
 
-        data_image_async = await async_fetch_image(data_url)
+        data_image_async = await connector.fetch_image_async(data_url)
         assert _image_equals(data_image_sync, data_image_async)
 
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
 async def test_fetch_image_local_files(image_url: str):
+    connector = MediaConnector()
+
     with TemporaryDirectory() as temp_dir:
-        origin_image = fetch_image(image_url)
+        local_connector = MediaConnector(allowed_local_media_path=temp_dir)
+
+        origin_image = connector.fetch_image(image_url)
         origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)),
                           quality=100,
                           icc_profile=origin_image.info.get('icc_profile'))
 
-        image_async = await async_fetch_image(
-            f"file://{temp_dir}/{os.path.basename(image_url)}",
-            allowed_local_media_path=temp_dir)
-
-        image_sync = fetch_image(
-            f"file://{temp_dir}/{os.path.basename(image_url)}",
-            allowed_local_media_path=temp_dir)
+        image_async = await local_connector.fetch_image_async(
+            f"file://{temp_dir}/{os.path.basename(image_url)}")
+        image_sync = local_connector.fetch_image(
+            f"file://{temp_dir}/{os.path.basename(image_url)}")
         # Check that the images are equal
         assert not ImageChops.difference(image_sync, image_async).getbbox()
 
-        with pytest.raises(ValueError):
-            await async_fetch_image(
-                f"file://{temp_dir}/../{os.path.basename(image_url)}",
-                allowed_local_media_path=temp_dir)
-        with pytest.raises(ValueError):
-            await async_fetch_image(
+        with pytest.raises(ValueError, match="must be a subpath"):
+            await local_connector.fetch_image_async(
+                f"file://{temp_dir}/../{os.path.basename(image_url)}")
+        with pytest.raises(RuntimeError, match="Cannot load local files"):
+            await connector.fetch_image_async(
                 f"file://{temp_dir}/../{os.path.basename(image_url)}")
 
-        with pytest.raises(ValueError):
-            fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}",
-                        allowed_local_media_path=temp_dir)
-        with pytest.raises(ValueError):
-            fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}")
+        with pytest.raises(ValueError, match="must be a subpath"):
+            local_connector.fetch_image(
+                f"file://{temp_dir}/../{os.path.basename(image_url)}")
+        with pytest.raises(RuntimeError, match="Cannot load local files"):
+            connector.fetch_image(
+                f"file://{temp_dir}/../{os.path.basename(image_url)}")
 
 
 @pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
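As the updated test shows, `file://` URLs only resolve when the connector was constructed with `allowed_local_media_path`, and paths that escape that directory are rejected. A small sketch of the access-control behavior (the directory and file names are illustrative):

```python
from vllm.multimodal.utils import MediaConnector

# Only files under /tmp/media may be loaded via file:// URLs.
local_connector = MediaConnector(allowed_local_media_path="/tmp/media")
image = local_connector.fetch_image("file:///tmp/media/cat.jpg")

# A connector without the setting raises RuntimeError for any file:// URL;
# a path that resolves outside the allowed directory (e.g. via "..") raises
# ValueError with "must be a subpath" in the message.
```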
@@ -21,12 +21,10 @@ class AudioAsset:
     name: Literal["winning_call", "mary_had_lamb"]
 
     @property
-    def audio_and_sample_rate(self) -> tuple[npt.NDArray, int]:
+    def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
         audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
                                             s3_prefix=ASSET_DIR)
-        y, sr = librosa.load(audio_path, sr=None)
-        assert isinstance(sr, int)
-        return y, sr
+        return librosa.load(audio_path, sr=None)
 
     @property
     def url(self) -> str:
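The property now returns `librosa.load`'s result directly, and the annotation is widened from `int` to `float` since librosa does not guarantee an integer sample rate, so the old `assert isinstance(sr, int)` was unnecessarily strict. A minimal sketch (assumes `librosa` is installed and a local `sample.ogg` exists):

```python
import librosa

# sr=None preserves the file's native sample rate instead of resampling.
y, sr = librosa.load("sample.ogg", sr=None)
print(y.shape, sr)  # sr is a number; librosa does not promise an int
```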
@@ -6,7 +6,7 @@ from collections import defaultdict, deque
 from functools import lru_cache, partial
 from pathlib import Path
 from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
-                    Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)
+                    Literal, Optional, Tuple, TypeVar, Union, cast)
 
 import jinja2.nodes
 import transformers.utils.chat_template_utils as hf_chat_utils
@@ -23,6 +23,8 @@ from openai.types.chat import (
     ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
 from openai.types.chat import (ChatCompletionMessageToolCallParam,
                                ChatCompletionToolMessageParam)
+from openai.types.chat.chat_completion_content_part_input_audio_param import (
+    InputAudio)
 # yapf: enable
 # pydantic needs the TypedDict from typing_extensions
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -31,11 +33,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
-from vllm.multimodal.utils import (async_get_and_parse_audio,
-                                   async_get_and_parse_image,
-                                   async_get_and_parse_video,
-                                   get_and_parse_audio, get_and_parse_image,
-                                   get_and_parse_video)
+from vllm.multimodal.utils import MediaConnector
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.utils import print_warning_once
 
@@ -368,14 +366,17 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
         self._tokenizer = tokenizer
         self._allowed_items = (model_config.multimodal_config.limit_per_prompt
                                if model_config.multimodal_config else {})
-        self._consumed_items = {k: 0 for k in self._allowed_items}
 
-        self._items: List[_T] = []
+        self._items_by_modality = defaultdict[str, list[_T]](list)
 
     @property
     def model_config(self) -> ModelConfig:
         return self._model_config
 
+    @property
+    def allowed_local_media_path(self):
+        return self._model_config.allowed_local_media_path
+
     @staticmethod
     @lru_cache(maxsize=None)
     def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
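The tracker now groups items per modality instead of keeping a flat list plus a separate counter. `defaultdict[str, list[_T]](list)` uses the subscripted builtin generic as a runtime constructor (valid on Python 3.9+). A standalone sketch of the data structure:

```python
from collections import defaultdict

# Same shape as self._items_by_modality in the tracker.
items_by_modality = defaultdict[str, list[object]](list)
items_by_modality["image"].append("img0")
items_by_modality["image"].append("img1")
items_by_modality["audio"].append("aud0")

# The per-modality count that add() checks against the configured limit:
print(len(items_by_modality["image"]))  # 2
```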
@@ -435,38 +436,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
         else:
             raise TypeError(f"Unknown modality: {modality}")
 
-    @staticmethod
-    def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
-        mm_lists: Mapping[str, List[object]] = defaultdict(list)
-
-        # Merge all the multi-modal items
-        for single_mm_data in items:
-            for mm_key, mm_item in single_mm_data.items():
-                if isinstance(mm_item, list):
-                    mm_lists[mm_key].extend(mm_item)
-                else:
-                    mm_lists[mm_key].append(mm_item)
-
-        # Unpack any single item lists for models that don't expect multiple.
-        return {
-            mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
-            for mm_key, mm_list in mm_lists.items()
-        }
-
     def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
         """
         Add a multi-modal item to the current prompt and returns the
         placeholder string to use, if any.
         """
         allowed_count = self._allowed_items.get(modality, 1)
-        current_count = self._consumed_items.get(modality, 0) + 1
+        current_count = len(self._items_by_modality[modality]) + 1
         if current_count > allowed_count:
             raise ValueError(
                 f"At most {allowed_count} {modality}(s) may be provided in "
                 "one request.")
 
-        self._consumed_items[modality] = current_count
-        self._items.append(item)
+        self._items_by_modality[modality].append(item)
 
         return self._placeholder_str(modality, current_count)
 
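With items grouped by modality up front, the removed `_combine` merge step reduces to a plain `dict()` copy, and the old "unpack single-item lists" behavior goes away: every modality now always maps to a list. A sketch of the difference (toy values):

```python
from collections import defaultdict

items_by_modality = defaultdict(list)
items_by_modality["image"].extend(["img0", "img1"])
items_by_modality["audio"].append("aud0")

# MultiModalItemTracker.all_mm_data() now just copies the grouping:
mm_data = dict(items_by_modality)  # {'image': ['img0', 'img1'], 'audio': ['aud0']}

# The removed _combine() would instead have unpacked single-item lists:
legacy = {k: v[0] if len(v) == 1 else v for k, v in items_by_modality.items()}
print(legacy)  # {'image': ['img0', 'img1'], 'audio': 'aud0'}
```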
@@ -475,22 +457,26 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
         raise NotImplementedError
 
 
-class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
+class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
 
     def all_mm_data(self) -> Optional[MultiModalDataDict]:
-        return self._combine(self._items) if self._items else None
+        if self._items_by_modality:
+            return dict(self._items_by_modality)
+
+        return None
 
     def create_parser(self) -> "BaseMultiModalContentParser":
         return MultiModalContentParser(self)
 
 
-class AsyncMultiModalItemTracker(
-        BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
+class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
 
     async def all_mm_data(self) -> Optional[MultiModalDataDict]:
-        if self._items:
-            items = await asyncio.gather(*self._items)
-            return self._combine(items)
+        if self._items_by_modality:
+            return {
+                modality: await asyncio.gather(*items)
+                for modality, items in self._items_by_modality.items()
+            }
 
         return None
 
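`AsyncMultiModalItemTracker` stores awaitables and now resolves them per modality with `asyncio.gather`. A self-contained sketch of the same pattern (toy coroutines stand in for the connector's async fetches):

```python
import asyncio
from collections import defaultdict


async def fake_fetch(name: str) -> str:  # stands in for fetch_image_async etc.
    return f"<{name}>"


async def main() -> None:
    items_by_modality = defaultdict(list)
    items_by_modality["image"] = [fake_fetch("img0"), fake_fetch("img1")]
    items_by_modality["audio"] = [fake_fetch("aud0")]

    # Same shape as AsyncMultiModalItemTracker.all_mm_data():
    mm_data = {
        modality: await asyncio.gather(*items)
        for modality, items in items_by_modality.items()
    }
    print(mm_data)  # {'image': ['<img0>', '<img1>'], 'audio': ['<aud0>']}


asyncio.run(main())
```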
@@ -522,7 +508,7 @@ class BaseMultiModalContentParser(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
+    def parse_input_audio(self, input_audio: InputAudio) -> None:
         raise NotImplementedError
 
     @abstractmethod
@@ -537,31 +523,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
 
         self._tracker = tracker
 
+        self._connector = MediaConnector(
+            allowed_local_media_path=tracker.allowed_local_media_path,
+        )
+
     def parse_image(self, image_url: str) -> None:
-        image = get_and_parse_image(image_url,
-                                    allowed_local_media_path=self._tracker.
-                                    _model_config.allowed_local_media_path)
+        image = self._connector.fetch_image(image_url)
 
         placeholder = self._tracker.add("image", image)
         self._add_placeholder(placeholder)
 
     def parse_audio(self, audio_url: str) -> None:
-        audio = get_and_parse_audio(audio_url)
+        audio = self._connector.fetch_audio(audio_url)
 
         placeholder = self._tracker.add("audio", audio)
         self._add_placeholder(placeholder)
 
-    def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
-        input_audio_data = input_audio.get("data","")
-        input_audio_format = input_audio.get("format","")
-        audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
-        audio = get_and_parse_audio(audio_url)
+    def parse_input_audio(self, input_audio: InputAudio) -> None:
+        audio_data = input_audio.get("data", "")
+        audio_format = input_audio.get("format", "")
+        audio_url = f"data:audio/{audio_format};base64,{audio_data}"
 
-        placeholder = self._tracker.add("audio", audio)
-        self._add_placeholder(placeholder)
+        return self.parse_audio(audio_url)
 
     def parse_video(self, video_url: str) -> None:
-        video = get_and_parse_video(video_url)
+        video = self._connector.fetch_video(video_url)
 
         placeholder = self._tracker.add("video", video)
         self._add_placeholder(placeholder)
@@ -573,33 +559,31 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
         super().__init__()
 
         self._tracker = tracker
+        self._connector = MediaConnector(
+            allowed_local_media_path=tracker.allowed_local_media_path,
+        )
 
     def parse_image(self, image_url: str) -> None:
-        image_coro = async_get_and_parse_image(
-            image_url,
-            allowed_local_media_path=self._tracker._model_config.
-            allowed_local_media_path)
+        image_coro = self._connector.fetch_image_async(image_url)
 
         placeholder = self._tracker.add("image", image_coro)
         self._add_placeholder(placeholder)
 
     def parse_audio(self, audio_url: str) -> None:
-        audio_coro = async_get_and_parse_audio(audio_url)
+        audio_coro = self._connector.fetch_audio_async(audio_url)
 
         placeholder = self._tracker.add("audio", audio_coro)
         self._add_placeholder(placeholder)
 
-    def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
-        input_audio_data = input_audio.get("data","")
-        input_audio_format = input_audio.get("format","")
-        audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
-        audio_coro = async_get_and_parse_audio(audio_url)
+    def parse_input_audio(self, input_audio: InputAudio) -> None:
+        audio_data = input_audio.get("data", "")
+        audio_format = input_audio.get("format", "")
+        audio_url = f"data:audio/{audio_format};base64,{audio_data}"
 
-        placeholder = self._tracker.add("audio", audio_coro)
-        self._add_placeholder(placeholder)
+        return self.parse_audio(audio_url)
 
     def parse_video(self, video_url: str) -> None:
-        video = async_get_and_parse_video(video_url)
+        video = self._connector.fetch_video_async(video_url)
 
         placeholder = self._tracker.add("video", video)
         self._add_placeholder(placeholder)
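Both parsers now normalize an OpenAI-style `input_audio` part into a base64 data URL and delegate to `parse_audio`, instead of duplicating the fetch logic. The normalization is plain string formatting:

```python
# How an input_audio part becomes a data URL (toy payload).
input_audio = {"data": "UklGRg==", "format": "wav"}

audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
print(audio_url)  # data:audio/wav;base64,UklGRg==
```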
@@ -695,10 +679,13 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
 _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
 _VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
 
+_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
+
 # Define a mapping from part types to their corresponding parsing functions.
-MM_PARSER_MAP: Dict[str,
-                    Callable[[ChatCompletionContentPartParam],
-                             Union[str, Dict[str,str]]]] = {
+MM_PARSER_MAP: Dict[
+    str,
+    Callable[[ChatCompletionContentPartParam], _ContentPart],
+] = {
     "text":
     lambda part: _TextParser(part).get("text", ""),
     "image_url":
@@ -715,8 +702,7 @@ MM_PARSER_MAP: Dict[str,
 
 
 def _parse_chat_message_content_mm_part(
-        part: ChatCompletionContentPartParam) -> Tuple[str,
-                                                       Union[str, Dict[str, str]]]:
+        part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]:
     """
     Parses a given multi-modal content part based on its type.
 
@@ -783,7 +769,7 @@ def _parse_chat_message_content_parts(
     *,
     wrap_dicts: bool,
 ) -> List[ConversationMessage]:
-    content: List[Union[str, Dict[str, str]]] = []
+    content = list[_ContentPart]()
 
     mm_parser = mm_tracker.create_parser()
 
@@ -814,7 +800,7 @@ def _parse_chat_message_content_part(
     mm_parser: BaseMultiModalContentParser,
     *,
     wrap_dicts: bool,
-) -> Optional[Union[str, Dict[str, str]]]:
+) -> Optional[_ContentPart]:
     """Parses a single part of a conversation. If wrap_dicts is True,
     structured dictionary pieces for texts and images will be
     wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
@@ -823,8 +809,7 @@ def _parse_chat_message_content_part(
     with multimodal placeholders.
     """
     if isinstance(part, str):  # Handle plain text parts
-        text = _TextParser(part)
-        return text
+        return part
 
     # Handle structured dictionary parts
     part_type, content = _parse_chat_message_content_mm_part(part)
@@ -855,7 +840,7 @@ def _parse_chat_message_content_part(
         return {'type': 'audio'} if wrap_dicts else None
 
     if part_type == "input_audio":
-        dict_content = cast(Dict[str, str], content)
+        dict_content = cast(InputAudio, content)
         mm_parser.parse_input_audio(dict_content)
         return {'type': 'audio'} if wrap_dicts else None
 
@@ -1,10 +1,14 @@
+import base64
+from io import BytesIO
+from pathlib import Path
+
 import numpy as np
 import numpy.typing as npt
 
 from vllm.inputs.registry import InputContext
 from vllm.utils import PlaceholderModule
 
-from .base import MultiModalPlugin
+from .base import MediaIO, MultiModalPlugin
 from .inputs import AudioItem, MultiModalData, MultiModalKwargs
 
 try:
@@ -12,6 +16,11 @@ try:
 except ImportError:
     librosa = PlaceholderModule("librosa")  # type: ignore[assignment]
 
+try:
+    import soundfile
+except ImportError:
+    soundfile = PlaceholderModule("soundfile")  # type: ignore[assignment]
+
 
 class AudioPlugin(MultiModalPlugin):
     """Plugin for audio data."""
@@ -39,3 +48,28 @@ def resample_audio(
     target_sr: float,
 ) -> npt.NDArray[np.floating]:
     return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
+
+
+class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
+
+    def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
+        return librosa.load(BytesIO(data), sr=None)
+
+    def load_base64(
+        self,
+        media_type: str,
+        data: str,
+    ) -> tuple[npt.NDArray, float]:
+        return self.load_bytes(base64.b64decode(data))
+
+    def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
+        return librosa.load(filepath, sr=None)
+
+    def encode_base64(self, media: tuple[npt.NDArray, float]) -> str:
+        audio, sr = media
+
+        with BytesIO() as buffer:
+            soundfile.write(buffer, audio, sr, format="WAV")
+            data = buffer.getvalue()
+
+        return base64.b64encode(data).decode('utf-8')
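A sketch of a round trip through the new `AudioMediaIO` (assumes `librosa` and `soundfile` are installed; the tone is synthetic):

```python
import numpy as np

from vllm.multimodal.audio import AudioMediaIO

audio_io = AudioMediaIO()

# One second of a 440 Hz sine at 16 kHz.
sr = 16000
t = np.linspace(0, 1, sr, endpoint=False)
tone = (0.1 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

b64 = audio_io.encode_base64((tone, sr))            # WAV bytes, base64-encoded
decoded, decoded_sr = audio_io.load_base64("audio/wav", b64)
```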
@@ -1,6 +1,7 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
+from pathlib import Path
+from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple,
                     Optional, Sequence, Tuple, Type, TypeVar, Union)
 
 from torch import nn
@@ -118,7 +119,7 @@ class MultiModalPlugin(ABC):
         self,
         model_config: "ModelConfig",
         data: MultiModalData[Any],
-        mm_processor_kwargs: Optional[Dict[str, Any]],
+        mm_processor_kwargs: Optional[dict[str, Any]],
     ) -> MultiModalKwargs:
         """
         Transform the data into a dictionary of model inputs using the
@@ -254,10 +255,10 @@ class MultiModalPlaceholderMap:
     """
 
     class IndexMap(NamedTuple):
-        src: List[int]
-        dest: List[int]
+        src: list[int]
+        dest: list[int]
 
-    src_ranges: List[range]
+    src_ranges: list[range]
     """
     The indices of the multi-modal embeddings that will replace the
     corresponding placeholder embeddings pointed to by ``dest_ranges``.
@@ -268,7 +269,7 @@ class MultiModalPlaceholderMap:
     The total number of flattened multi-modal embeddings.
     """
 
-    dest_ranges: List[range]
+    dest_ranges: list[range]
     """
     The indices of the placeholder embeddings that will be replaced by the
     multimodal embeddings.
@@ -288,7 +289,7 @@ class MultiModalPlaceholderMap:
     @classmethod
     def from_seq_group(
         cls, seq_group: "SequenceGroupMetadata", positions: range
-    ) -> Tuple[Optional[MultiModalDataDict], Dict[str,
+    ) -> Tuple[Optional[MultiModalDataDict], dict[str,
                                                   "MultiModalPlaceholderMap"]]:
         """
         Returns the multi-modal items that intersect with the portion of a
@@ -376,9 +377,9 @@ class MultiModalPlaceholderMap:
     def append_items_from_seq_group(
         self,
         positions: range,
-        multi_modal_items: List[_T],
+        multi_modal_items: list[_T],
         multi_modal_placeholders: Sequence[PlaceholderRange],
-    ) -> List[_T]:
+    ) -> list[_T]:
         """
         Adds the multi-modal items that intersect ``positions`` to this
         placeholder map and returns the intersecting items.
@@ -454,3 +455,22 @@ class MultiModalPlaceholderMap:
 
         return MultiModalPlaceholderMap.IndexMap(src=src_indices,
                                                  dest=dest_indices)
+
+
+class MediaIO(ABC, Generic[_T]):
+
+    @abstractmethod
+    def load_bytes(self, data: bytes) -> _T:
+        raise NotImplementedError
+
+    @abstractmethod
+    def load_base64(self, media_type: str, data: str) -> _T:
+        """
+        List of media types:
+        https://www.iana.org/assignments/media-types/media-types.xhtml
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def load_file(self, filepath: Path) -> _T:
+        raise NotImplementedError
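The new `MediaIO` ABC is what `MediaConnector` programs against: any type that can be decoded from raw bytes, a base64 payload, or a local file can plug into the same URL-dispatch logic. A hypothetical subclass for plain-text media (illustrative only, not part of this commit):

```python
import base64
from pathlib import Path

from vllm.multimodal.base import MediaIO


class TextMediaIO(MediaIO[str]):
    """Hypothetical example: decode text/* media into str."""

    def load_bytes(self, data: bytes) -> str:
        return data.decode("utf-8")

    def load_base64(self, media_type: str, data: str) -> str:
        return self.load_bytes(base64.b64decode(data))

    def load_file(self, filepath: Path) -> str:
        return filepath.read_text(encoding="utf-8")
```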
@@ -1,4 +1,7 @@
+import base64
 from functools import lru_cache
+from io import BytesIO
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional
 
 import torch
@@ -9,7 +12,7 @@ from vllm.logger import init_logger
 from vllm.transformers_utils.processor import get_image_processor
 from vllm.utils import is_list_of
 
-from .base import MultiModalPlugin
+from .base import MediaIO, MultiModalPlugin
 from .inputs import ImageItem, MultiModalData, MultiModalKwargs
 
 if TYPE_CHECKING:
@@ -96,3 +99,39 @@ def rescale_image_size(image: Image.Image,
     if transpose >= 0:
         image = image.transpose(Image.Transpose(transpose))
     return image
+
+
+class ImageMediaIO(MediaIO[Image.Image]):
+
+    def __init__(self, *, image_mode: str = "RGB") -> None:
+        super().__init__()
+
+        self.image_mode = image_mode
+
+    def load_bytes(self, data: bytes) -> Image.Image:
+        image = Image.open(BytesIO(data))
+        image.load()
+        return image.convert(self.image_mode)
+
+    def load_base64(self, media_type: str, data: str) -> Image.Image:
+        return self.load_bytes(base64.b64decode(data))
+
+    def load_file(self, filepath: Path) -> Image.Image:
+        image = Image.open(filepath)
+        image.load()
+        return image.convert(self.image_mode)
+
+    def encode_base64(
+        self,
+        media: Image.Image,
+        *,
+        image_format: str = "JPEG",
+    ) -> str:
+        image = media
+
+        with BytesIO() as buffer:
+            image = image.convert(self.image_mode)
+            image.save(buffer, image_format)
+            data = buffer.getvalue()
+
+        return base64.b64encode(data).decode('utf-8')
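A sketch of a base64 round trip through `ImageMediaIO` (Pillow only; the image is generated in memory, and PNG is chosen because it is lossless):

```python
from PIL import Image

from vllm.multimodal.image import ImageMediaIO

image_io = ImageMediaIO(image_mode="RGB")

image = Image.new("RGB", (8, 8), color=(255, 0, 0))
b64 = image_io.encode_base64(image, image_format="PNG")
decoded = image_io.load_base64("image/png", b64)
assert list(decoded.getdata()) == list(image.getdata())
```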
@@ -1,8 +1,7 @@
 import base64
-import os
 from functools import lru_cache
-from io import BytesIO
-from typing import List, Optional, Tuple, TypeVar, Union
+from pathlib import Path
+from typing import Optional, TypeVar, Union
 from urllib.parse import ParseResult, urlparse
 
 import numpy as np
 import numpy.typing as npt
@@ -10,283 +9,246 @@ import torch
 from PIL import Image
 
 import vllm.envs as envs
-from vllm.connections import global_http_connection
+from vllm.connections import HTTPConnection, global_http_connection
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
 from vllm.utils import PlaceholderModule
 
-from .inputs import MultiModalDataDict, PlaceholderRange
-
-try:
-    import decord
-except ImportError:
-    decord = PlaceholderModule("decord")  # type: ignore[assignment]
-
-try:
-    import librosa
-except ImportError:
-    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]
-
-try:
-    import soundfile
-except ImportError:
-    soundfile = PlaceholderModule("soundfile")  # type: ignore[assignment]
+from .audio import AudioMediaIO
+from .base import MediaIO
+from .image import ImageMediaIO
+from .inputs import PlaceholderRange
+from .video import VideoMediaIO
 
 logger = init_logger(__name__)
 
 cached_get_tokenizer = lru_cache(get_tokenizer)
 
-
-def _load_image_from_bytes(b: bytes) -> Image.Image:
-    image = Image.open(BytesIO(b))
-    image.load()
-    return image
-
-
-def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool:
-    # Get the common path
-    common_path = os.path.commonpath([
-        os.path.abspath(image_path),
-        os.path.abspath(allowed_local_media_path)
-    ])
-    # Check if the common path is the same as allowed_local_media_path
-    return common_path == os.path.abspath(allowed_local_media_path)
-
-
-def _load_image_from_file(image_url: str,
-                          allowed_local_media_path: str) -> Image.Image:
-    if not allowed_local_media_path:
-        raise ValueError("Invalid 'image_url': Cannot load local files without"
-                         "'--allowed-local-media-path'.")
-    if allowed_local_media_path:
-        if not os.path.exists(allowed_local_media_path):
-            raise ValueError(
-                "Invalid '--allowed-local-media-path': "
-                f"The path {allowed_local_media_path} does not exist.")
-        if not os.path.isdir(allowed_local_media_path):
-            raise ValueError(
-                "Invalid '--allowed-local-media-path': "
-                f"The path {allowed_local_media_path} must be a directory.")
-
-    # Only split once and assume the second part is the image path
-    _, image_path = image_url.split("file://", 1)
-    if not _is_subpath(image_path, allowed_local_media_path):
-        raise ValueError(
-            f"Invalid 'image_url': The file path {image_path} must"
-            " be a subpath of '--allowed-local-media-path'"
-            f" '{allowed_local_media_path}'.")
-
-    image = Image.open(image_path)
-    image.load()
-    return image
-
-
-def _load_image_from_data_url(image_url: str) -> Image.Image:
-    # Only split once and assume the second part is the base64 encoded image
-    _, image_base64 = image_url.split(",", 1)
-    return load_image_from_base64(image_base64)
-
-
-def fetch_image(image_url: str,
-                *,
-                image_mode: str = "RGB",
-                allowed_local_media_path: str = "") -> Image.Image:
-    """
-    Load a PIL image from a HTTP or base64 data URL.
-
-    By default, the image is converted into RGB format.
-    """
-    if image_url.startswith('http'):
-        image_raw = global_http_connection.get_bytes(
-            image_url,
-            timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
-        )
-        image = _load_image_from_bytes(image_raw)
-
-    elif image_url.startswith('data:image'):
-        image = _load_image_from_data_url(image_url)
-    elif image_url.startswith('file://'):
-        image = _load_image_from_file(image_url, allowed_local_media_path)
-    else:
-        raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
-                         "with either 'data:image', 'file://' or 'http'.")
-
-    return image.convert(image_mode)
-
-
-async def async_fetch_image(image_url: str,
-                            *,
-                            image_mode: str = "RGB",
-                            allowed_local_media_path: str = "") -> Image.Image:
-    """
-    Asynchronously load a PIL image from a HTTP or base64 data URL.
-
-    By default, the image is converted into RGB format.
-    """
-    if image_url.startswith('http'):
-        image_raw = await global_http_connection.async_get_bytes(
-            image_url,
-            timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
-        )
-        image = _load_image_from_bytes(image_raw)
-
-    elif image_url.startswith('data:image'):
-        image = _load_image_from_data_url(image_url)
-    elif image_url.startswith('file://'):
-        image = _load_image_from_file(image_url, allowed_local_media_path)
-    else:
-        raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
-                         "with either 'data:image', 'file://' or 'http'.")
-
-    return image.convert(image_mode)
-
-
-def _load_video_from_bytes(b: bytes, num_frames: int = 32) -> npt.NDArray:
-    video_path = BytesIO(b)
-    vr = decord.VideoReader(video_path, num_threads=1)
-    total_frame_num = len(vr)
-
-    if total_frame_num > num_frames:
-        uniform_sampled_frames = np.linspace(0,
-                                             total_frame_num - 1,
-                                             num_frames,
-                                             dtype=int)
-        frame_idx = uniform_sampled_frames.tolist()
-    else:
-        frame_idx = [i for i in range(0, total_frame_num)]
-    frames = vr.get_batch(frame_idx).asnumpy()
-
-    return frames
-
-
-def _load_video_from_data_url(video_url: str) -> npt.NDArray:
-    # Only split once and assume the second part is the base64 encoded video
-    _, video_base64 = video_url.split(",", 1)
-
-    if video_url.startswith("data:video/jpeg;"):
-        return np.stack([
-            np.array(load_image_from_base64(frame_base64))
-            for frame_base64 in video_base64.split(",")
-        ])
-
-    return load_video_from_base64(video_base64)
-
-
-def fetch_video(video_url: str, *, num_frames: int = 32) -> npt.NDArray:
-    """
-    Load video from a HTTP or base64 data URL.
-    """
-    if video_url.startswith('http') or video_url.startswith('https'):
-        video_raw = global_http_connection.get_bytes(
-            video_url,
-            timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
-        )
-        video = _load_video_from_bytes(video_raw, num_frames)
-    elif video_url.startswith('data:video'):
-        video = _load_video_from_data_url(video_url)
-    else:
-        raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
-                         "with either 'data:video' or 'http'.")
-    return video
-
-
-async def async_fetch_video(video_url: str,
-                            *,
-                            num_frames: int = 32) -> npt.NDArray:
-    """
-    Asynchronously load video from a HTTP or base64 data URL.
-
-    By default, the image is converted into RGB format.
-    """
-    if video_url.startswith('http') or video_url.startswith('https'):
-        video_raw = await global_http_connection.async_get_bytes(
-            video_url,
-            timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
-        )
-        video = _load_video_from_bytes(video_raw, num_frames)
-    elif video_url.startswith('data:video'):
-        video = _load_video_from_data_url(video_url)
-    else:
-        raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
-                         "with either 'data:video' or 'http'.")
-    return video
-
-
-def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
-    """
-    Load audio from a URL.
-    """
-    if audio_url.startswith("http"):
-        audio_bytes = global_http_connection.get_bytes(
-            audio_url,
-            timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
-        )
-    elif audio_url.startswith("data:audio"):
-        _, audio_base64 = audio_url.split(",", 1)
-        audio_bytes = base64.b64decode(audio_base64)
-    else:
-        raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
-                         "with either 'data:audio' or 'http'.")
-
-    return librosa.load(BytesIO(audio_bytes), sr=None)
-
-
-async def async_fetch_audio(
-        audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
-    """
-    Asynchronously fetch audio from a URL.
-    """
-    if audio_url.startswith("http"):
-        audio_bytes = await global_http_connection.async_get_bytes(
-            audio_url,
-            timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
-        )
-    elif audio_url.startswith("data:audio"):
-        _, audio_base64 = audio_url.split(",", 1)
-        audio_bytes = base64.b64decode(audio_base64)
-    else:
-        raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
-                         "with either 'data:audio' or 'http'.")
-
-    return librosa.load(BytesIO(audio_bytes), sr=None)
-
-
-def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
-    audio, sr = fetch_audio(audio_url)
-    return {"audio": (audio, sr)}
-
-
-def get_and_parse_image(
-    image_url: str,
-    *,
-    allowed_local_media_path: str = "") -> MultiModalDataDict:
-    image = fetch_image(image_url,
-                        allowed_local_media_path=allowed_local_media_path)
-    return {"image": image}
-
-
-def get_and_parse_video(video_url: str) -> MultiModalDataDict:
-    video = fetch_video(video_url)
-    return {"video": video}
-
-
-async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
-    audio, sr = await async_fetch_audio(audio_url)
-    return {"audio": (audio, sr)}
-
-
-async def async_get_and_parse_image(
-    image_url: str,
-    *,
-    allowed_local_media_path: str = "") -> MultiModalDataDict:
-    image = await async_fetch_image(
-        image_url, allowed_local_media_path=allowed_local_media_path)
-    return {"image": image}
-
-
-async def async_get_and_parse_video(video_url: str) -> MultiModalDataDict:
-    video = await async_fetch_video(video_url)
-    return {"video": video}
+_M = TypeVar("_M")
+
+
+class MediaConnector:
+
+    def __init__(
+        self,
+        connection: HTTPConnection = global_http_connection,
+        *,
+        allowed_local_media_path: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.connection = connection
+
+        if allowed_local_media_path:
+            allowed_local_media_path_ = Path(allowed_local_media_path)
+
+            if not allowed_local_media_path_.exists():
+                raise ValueError(
+                    "Invalid `--allowed-local-media-path`: The path "
+                    f"{allowed_local_media_path_} does not exist.")
+            if not allowed_local_media_path_.is_dir():
+                raise ValueError(
+                    "Invalid `--allowed-local-media-path`: The path "
+                    f"{allowed_local_media_path_} must be a directory.")
+        else:
+            allowed_local_media_path_ = None
+
+        self.allowed_local_media_path = allowed_local_media_path_
+
+    def _load_data_url(
+        self,
+        url_spec: ParseResult,
+        media_io: MediaIO[_M],
+    ) -> _M:
+        data_spec, data = url_spec.path.split(",", 1)
+        media_type, data_type = data_spec.split(";", 1)
+
+        if data_type != "base64":
+            msg = "Only base64 data URLs are supported for now."
+            raise NotImplementedError(msg)
+
+        return media_io.load_base64(media_type, data)
+
+    def _load_file_url(
+        self,
+        url_spec: ParseResult,
+        media_io: MediaIO[_M],
+    ) -> _M:
+        allowed_local_media_path = self.allowed_local_media_path
+        if allowed_local_media_path is None:
+            raise RuntimeError("Cannot load local files without "
+                               "`--allowed-local-media-path`.")
+
+        filepath = Path(url_spec.path)
+        if allowed_local_media_path not in filepath.resolve().parents:
+            raise ValueError(
+                f"The file path {filepath} must be a subpath "
+                f"of `--allowed-local-media-path` {allowed_local_media_path}.")
+
+        return media_io.load_file(filepath)
+
+    def load_from_url(
+        self,
+        url: str,
+        media_io: MediaIO[_M],
+        *,
+        fetch_timeout: Optional[int] = None,
+    ) -> _M:
+        url_spec = urlparse(url)
+
+        if url_spec.scheme.startswith("http"):
+            connection = self.connection
+            data = connection.get_bytes(url, timeout=fetch_timeout)
+
+            return media_io.load_bytes(data)
+
+        if url_spec.scheme == "data":
+            return self._load_data_url(url_spec, media_io)
+
+        if url_spec.scheme == "file":
+            return self._load_file_url(url_spec, media_io)
+
+        msg = "The URL must be either a HTTP, data or file URL."
+        raise ValueError(msg)
+
+    async def load_from_url_async(
+        self,
+        url: str,
+        media_io: MediaIO[_M],
+        *,
+        fetch_timeout: Optional[int] = None,
+    ) -> _M:
+        url_spec = urlparse(url)
+
+        if url_spec.scheme.startswith("http"):
+            connection = self.connection
+            data = await connection.async_get_bytes(url, timeout=fetch_timeout)
+
+            return media_io.load_bytes(data)
+
+        if url_spec.scheme == "data":
+            return self._load_data_url(url_spec, media_io)
+
+        if url_spec.scheme == "file":
+            return self._load_file_url(url_spec, media_io)
+
+        msg = "The URL must be either a HTTP, data or file URL."
+        raise ValueError(msg)
+
+    def fetch_audio(
+        self,
+        audio_url: str,
+    ) -> tuple[np.ndarray, Union[int, float]]:
+        """
+        Load audio from a URL.
+        """
+        audio_io = AudioMediaIO()
+
+        return self.load_from_url(
+            audio_url,
+            audio_io,
+            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
+        )
+
+    async def fetch_audio_async(
+        self,
+        audio_url: str,
+    ) -> tuple[np.ndarray, Union[int, float]]:
+        """
+        Asynchronously fetch audio from a URL.
+        """
+        audio_io = AudioMediaIO()
+
+        return await self.load_from_url_async(
+            audio_url,
+            audio_io,
+            fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
+        )
+
+    def fetch_image(
+        self,
+        image_url: str,
+        *,
+        image_mode: str = "RGB",
+    ) -> Image.Image:
+        """
+        Load a PIL image from a HTTP or base64 data URL.
+
+        By default, the image is converted into RGB format.
+        """
+        image_io = ImageMediaIO(image_mode=image_mode)
+
+        return self.load_from_url(
+            image_url,
+            image_io,
+            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
+        )
+
+    async def fetch_image_async(
+        self,
+        image_url: str,
+        *,
+        image_mode: str = "RGB",
+    ) -> Image.Image:
+        """
+        Asynchronously load a PIL image from a HTTP or base64 data URL.
+
+        By default, the image is converted into RGB format.
+        """
+        image_io = ImageMediaIO(image_mode=image_mode)
+
+        return await self.load_from_url_async(
+            image_url,
+            image_io,
+            fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
+        )
+
+    def fetch_video(
+        self,
+        video_url: str,
+        *,
+        image_mode: str = "RGB",
+        num_frames: int = 32,
+    ) -> npt.NDArray:
+        """
+        Load video from a HTTP or base64 data URL.
+        """
+        image_io = ImageMediaIO(image_mode=image_mode)
+        video_io = VideoMediaIO(image_io, num_frames=num_frames)
+
+        return self.load_from_url(
+            video_url,
+            video_io,
+            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
+        )
+
+    async def fetch_video_async(
+        self,
+        video_url: str,
+        *,
+        image_mode: str = "RGB",
+        num_frames: int = 32,
+    ) -> npt.NDArray:
+        """
+        Asynchronously load video from a HTTP or base64 data URL.
+
+        By default, the image is converted into RGB format.
+        """
+        image_io = ImageMediaIO(image_mode=image_mode)
+        video_io = VideoMediaIO(image_io, num_frames=num_frames)
+
+        return await self.load_from_url_async(
+            video_url,
+            video_io,
+            fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
+        )
+
+
+global_media_connector = MediaConnector()
+"""The global :class:`MediaConnector` instance used by vLLM."""
+
+fetch_audio = global_media_connector.fetch_audio
+fetch_image = global_media_connector.fetch_image
+fetch_video = global_media_connector.fetch_video
 
 
 def encode_audio_base64(
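`load_from_url` is now the single dispatch point: `http(s)` URLs go through the injected `HTTPConnection`, `data:` URLs through `MediaIO.load_base64`, and `file:` URLs through the subpath check. The old module-level `fetch_audio`/`fetch_image`/`fetch_video` names survive as bound methods of a default instance, so existing imports keep working. A sketch (URLs and paths are illustrative):

```python
from vllm.multimodal.utils import MediaConnector, fetch_image

# fetch_image is now global_media_connector.fetch_image, so this still works:
img = fetch_image("https://example.com/cat.jpg")

# Equivalent explicit form, with local file access enabled:
connector = MediaConnector(allowed_local_media_path="/tmp/media")
img2 = connector.fetch_image("file:///tmp/media/cat.jpg")
```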
@@ -294,10 +256,8 @@ def encode_audio_base64(
     sampling_rate: int,
 ) -> str:
     """Encode audio as base64."""
-    buffered = BytesIO()
-    soundfile.write(buffered, audio, sampling_rate, format="WAV")
-
-    return base64.b64encode(buffered.getvalue()).decode('utf-8')
+    audio_io = AudioMediaIO()
+    return audio_io.encode_base64((audio, sampling_rate))
 
 
 def encode_image_base64(
@@ -311,29 +271,14 @@ def encode_image_base64(
 
     By default, the image is converted into RGB format before being encoded.
     """
-    buffered = BytesIO()
-    image = image.convert(image_mode)
-    image.save(buffered, format)
-    return base64.b64encode(buffered.getvalue()).decode('utf-8')
-
-
-def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
-    """Load image from base64 format."""
-    return _load_image_from_bytes(base64.b64decode(image))
+    image_io = ImageMediaIO(image_mode=image_mode)
+    return image_io.encode_base64(image, image_format=format)
 
 
 def encode_video_base64(frames: npt.NDArray) -> str:
-    base64_frames = []
-    frames_list = [frames[i] for i in range(frames.shape[0])]
-    for frame in frames_list:
-        img_base64 = encode_image_base64(Image.fromarray(frame))
-        base64_frames.append(img_base64)
-    return ",".join(base64_frames)
-
-
-def load_video_from_base64(video: Union[bytes, str]) -> npt.NDArray:
-    """Load video from base64 format."""
-    return _load_video_from_bytes(base64.b64decode(video))
+    image_io = ImageMediaIO()
+    video_io = VideoMediaIO(image_io)
+    return video_io.encode_base64(frames)
 
 
 def resolve_visual_encoder_outputs(
@@ -389,7 +334,7 @@ def repeat_and_pad_token(
     repeat_count: int = 1,
     pad_token_left: Optional[_T] = None,
     pad_token_right: Optional[_T] = None,
-) -> List[_T]:
+) -> list[_T]:
     replacement = [token] * repeat_count
     if pad_token_left is not None:
         replacement = [pad_token_left] + replacement
@@ -402,13 +347,13 @@ def repeat_and_pad_token(
 def repeat_and_pad_placeholder_tokens(
     tokenizer: AnyTokenizer,
     prompt: Optional[str],
-    prompt_token_ids: List[int],
+    prompt_token_ids: list[int],
     *,
     placeholder_token_id: int,
-    repeat_count: Union[int, List[int]],
+    repeat_count: Union[int, list[int]],
     pad_token_left: Optional[int] = None,
     pad_token_right: Optional[int] = None,
-) -> Tuple[Optional[str], List[int], List[PlaceholderRange]]:
+) -> tuple[Optional[str], list[int], list[PlaceholderRange]]:
     if isinstance(repeat_count, int):
         repeat_count = [repeat_count]
 
@@ -450,8 +395,8 @@ def repeat_and_pad_placeholder_tokens(
         new_prompt += prompt_parts[i] + replacement_str
     new_prompt += prompt_parts[-1]
 
-    new_token_ids: List[int] = []
-    placeholder_ranges: List[PlaceholderRange] = []
+    new_token_ids = list[int]()
+    placeholder_ranges = list[PlaceholderRange]()
     placeholder_token_idx = 0
     for i, token in enumerate(prompt_token_ids):
         if token == placeholder_token_id:
@@ -481,7 +426,7 @@ def repeat_and_pad_placeholder_tokens(
 def consecutive_placeholder_ranges(
         num_items: int,
         item_size: int,
-        initial_offset: int = 0) -> List[PlaceholderRange]:
+        initial_offset: int = 0) -> list[PlaceholderRange]:
     """Returns a list of consecutive PlaceholderRanges of a fixed size"""
 
     return [
@@ -1,23 +1,32 @@
-from functools import lru_cache
+import base64
+from functools import lru_cache, partial
+from io import BytesIO
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional
 
 import cv2
 import numpy as np
 import numpy.typing as npt
+from PIL import Image
 
 from vllm.inputs.registry import InputContext
 from vllm.logger import init_logger
 from vllm.transformers_utils.processor import get_video_processor
 from vllm.transformers_utils.tokenizer import get_tokenizer
-from vllm.utils import is_list_of
+from vllm.utils import PlaceholderModule, is_list_of
 
-from .base import MultiModalData
-from .image import ImagePlugin
+from .base import MediaIO, MultiModalData
+from .image import ImageMediaIO, ImagePlugin
 from .inputs import MultiModalKwargs, VideoItem
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig
 
+try:
+    import decord
+except ImportError:
+    decord = PlaceholderModule("decord")  # type: ignore[assignment]
+
 logger = init_logger(__name__)
 
 cached_get_video_processor = lru_cache(get_video_processor)
@@ -107,3 +116,73 @@ def sample_frames_from_video(frames: npt.NDArray,
     frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
     sampled_frames = frames[frame_indices, ...]
     return sampled_frames
+
+
+class VideoMediaIO(MediaIO[npt.NDArray]):
+
+    def __init__(
+        self,
+        image_io: ImageMediaIO,
+        *,
+        num_frames: int = 32,
+    ) -> None:
+        super().__init__()
+
+        self.image_io = image_io
+        self.num_frames = num_frames
+
+    def load_bytes(self, data: bytes) -> npt.NDArray:
+        vr = decord.VideoReader(BytesIO(data), num_threads=1)
+        total_frame_num = len(vr)
+
+        num_frames = self.num_frames
+        if total_frame_num > num_frames:
+            uniform_sampled_frames = np.linspace(0,
+                                                 total_frame_num - 1,
+                                                 num_frames,
+                                                 dtype=int)
+            frame_idx = uniform_sampled_frames.tolist()
+        else:
+            frame_idx = list(range(0, total_frame_num))
+
+        return vr.get_batch(frame_idx).asnumpy()
+
+    def load_base64(self, media_type: str, data: str) -> npt.NDArray:
+        if media_type.lower() == "video/jpeg":
+            load_frame = partial(
+                self.image_io.load_base64,
+                "image/jpeg",
+            )
+
+            return np.stack([
+                np.array(load_frame(frame_data))
+                for frame_data in data.split(",")
+            ])
+
+        return self.load_bytes(base64.b64decode(data))
+
+    def load_file(self, filepath: Path) -> npt.NDArray:
+        with filepath.open("rb") as f:
+            data = f.read()
+
+        return self.load_bytes(data)
+
+    def encode_base64(
+        self,
+        media: npt.NDArray,
+        *,
+        video_format: str = "JPEG",
+    ) -> str:
+        video = media
+
+        if video_format == "JPEG":
+            encode_frame = partial(
+                self.image_io.encode_base64,
+                image_format=video_format,
+            )
+
+            return ",".join(
+                encode_frame(Image.fromarray(frame)) for frame in video)
+
+        msg = "Only JPEG format is supported for now."
+        raise NotImplementedError(msg)
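`VideoMediaIO.load_base64` special-cases the `video/jpeg` type produced by `encode_video_base64`: the payload is a comma-separated list of base64 JPEG frames rather than a real container format. A sketch of the round trip (assumes the vLLM multimodal dependencies, Pillow and NumPy, are available; the frames are synthetic):

```python
import numpy as np

from vllm.multimodal.image import ImageMediaIO
from vllm.multimodal.video import VideoMediaIO

image_io = ImageMediaIO()
video_io = VideoMediaIO(image_io, num_frames=32)

# Four random 16x16 RGB frames.
frames = np.random.randint(0, 256, size=(4, 16, 16, 3), dtype=np.uint8)

payload = video_io.encode_base64(frames)             # "b64frame,b64frame,..."
decoded = video_io.load_base64("video/jpeg", payload)
assert decoded.shape == frames.shape  # JPEG is lossy, but the shape survives
```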