mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 13:41:20 +08:00
[Misc] Abstract the logic for reading and writing media content (#11527)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
2c9b8ea2b0
commit
7af553ea30
@ -33,6 +33,7 @@ class MockModelConfig:
|
|||||||
hf_config = MockHFConfig()
|
hf_config = MockHFConfig()
|
||||||
logits_processor_pattern = None
|
logits_processor_pattern = None
|
||||||
diff_sampling_param: Optional[dict] = None
|
diff_sampling_param: Optional[dict] = None
|
||||||
|
allowed_local_media_path: str = ""
|
||||||
|
|
||||||
def get_diff_sampling_param(self):
|
def get_diff_sampling_param(self):
|
||||||
return self.diff_sampling_param or {}
|
return self.diff_sampling_param or {}
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import warnings
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from vllm.assets.image import ImageAsset
|
from vllm.assets.image import ImageAsset
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
@ -91,10 +90,7 @@ def _assert_mm_data_is_image_input(
|
|||||||
image_data = mm_data.get("image")
|
image_data = mm_data.get("image")
|
||||||
assert image_data is not None
|
assert image_data is not None
|
||||||
|
|
||||||
if image_count == 1:
|
assert isinstance(image_data, list) and len(image_data) == image_count
|
||||||
assert isinstance(image_data, Image.Image)
|
|
||||||
else:
|
|
||||||
assert isinstance(image_data, list) and len(image_data) == image_count
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_chat_messages_single_image(
|
def test_parse_chat_messages_single_image(
|
||||||
|
|||||||
@ -9,7 +9,7 @@ import pytest
|
|||||||
from PIL import Image, ImageChops
|
from PIL import Image, ImageChops
|
||||||
from transformers import AutoConfig, AutoTokenizer
|
from transformers import AutoConfig, AutoTokenizer
|
||||||
|
|
||||||
from vllm.multimodal.utils import (async_fetch_image, fetch_image,
|
from vllm.multimodal.utils import (MediaConnector,
|
||||||
repeat_and_pad_placeholder_tokens)
|
repeat_and_pad_placeholder_tokens)
|
||||||
|
|
||||||
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
|
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
|
||||||
@ -23,7 +23,12 @@ TEST_IMAGE_URLS = [
|
|||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def url_images() -> Dict[str, Image.Image]:
|
def url_images() -> Dict[str, Image.Image]:
|
||||||
return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS}
|
connector = MediaConnector()
|
||||||
|
|
||||||
|
return {
|
||||||
|
image_url: connector.fetch_image(image_url)
|
||||||
|
for image_url in TEST_IMAGE_URLS
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_supported_suffixes() -> Tuple[str, ...]:
|
def get_supported_suffixes() -> Tuple[str, ...]:
|
||||||
@ -43,8 +48,10 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
||||||
async def test_fetch_image_http(image_url: str):
|
async def test_fetch_image_http(image_url: str):
|
||||||
image_sync = fetch_image(image_url)
|
connector = MediaConnector()
|
||||||
image_async = await async_fetch_image(image_url)
|
|
||||||
|
image_sync = connector.fetch_image(image_url)
|
||||||
|
image_async = await connector.fetch_image_async(image_url)
|
||||||
assert _image_equals(image_sync, image_async)
|
assert _image_equals(image_sync, image_async)
|
||||||
|
|
||||||
|
|
||||||
@ -53,6 +60,7 @@ async def test_fetch_image_http(image_url: str):
|
|||||||
@pytest.mark.parametrize("suffix", get_supported_suffixes())
|
@pytest.mark.parametrize("suffix", get_supported_suffixes())
|
||||||
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
|
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
|
||||||
image_url: str, suffix: str):
|
image_url: str, suffix: str):
|
||||||
|
connector = MediaConnector()
|
||||||
url_image = url_images[image_url]
|
url_image = url_images[image_url]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -75,48 +83,49 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
|
|||||||
base64_image = base64.b64encode(f.read()).decode("utf-8")
|
base64_image = base64.b64encode(f.read()).decode("utf-8")
|
||||||
data_url = f"data:{mime_type};base64,{base64_image}"
|
data_url = f"data:{mime_type};base64,{base64_image}"
|
||||||
|
|
||||||
data_image_sync = fetch_image(data_url)
|
data_image_sync = connector.fetch_image(data_url)
|
||||||
if _image_equals(url_image, Image.open(f)):
|
if _image_equals(url_image, Image.open(f)):
|
||||||
assert _image_equals(url_image, data_image_sync)
|
assert _image_equals(url_image, data_image_sync)
|
||||||
else:
|
else:
|
||||||
pass # Lossy format; only check that image can be opened
|
pass # Lossy format; only check that image can be opened
|
||||||
|
|
||||||
data_image_async = await async_fetch_image(data_url)
|
data_image_async = await connector.fetch_image_async(data_url)
|
||||||
assert _image_equals(data_image_sync, data_image_async)
|
assert _image_equals(data_image_sync, data_image_async)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
||||||
async def test_fetch_image_local_files(image_url: str):
|
async def test_fetch_image_local_files(image_url: str):
|
||||||
|
connector = MediaConnector()
|
||||||
|
|
||||||
with TemporaryDirectory() as temp_dir:
|
with TemporaryDirectory() as temp_dir:
|
||||||
origin_image = fetch_image(image_url)
|
local_connector = MediaConnector(allowed_local_media_path=temp_dir)
|
||||||
|
|
||||||
|
origin_image = connector.fetch_image(image_url)
|
||||||
origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)),
|
origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)),
|
||||||
quality=100,
|
quality=100,
|
||||||
icc_profile=origin_image.info.get('icc_profile'))
|
icc_profile=origin_image.info.get('icc_profile'))
|
||||||
|
|
||||||
image_async = await async_fetch_image(
|
image_async = await local_connector.fetch_image_async(
|
||||||
f"file://{temp_dir}/{os.path.basename(image_url)}",
|
f"file://{temp_dir}/{os.path.basename(image_url)}")
|
||||||
allowed_local_media_path=temp_dir)
|
image_sync = local_connector.fetch_image(
|
||||||
|
f"file://{temp_dir}/{os.path.basename(image_url)}")
|
||||||
image_sync = fetch_image(
|
|
||||||
f"file://{temp_dir}/{os.path.basename(image_url)}",
|
|
||||||
allowed_local_media_path=temp_dir)
|
|
||||||
# Check that the images are equal
|
# Check that the images are equal
|
||||||
assert not ImageChops.difference(image_sync, image_async).getbbox()
|
assert not ImageChops.difference(image_sync, image_async).getbbox()
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="must be a subpath"):
|
||||||
await async_fetch_image(
|
await local_connector.fetch_image_async(
|
||||||
f"file://{temp_dir}/../{os.path.basename(image_url)}",
|
f"file://{temp_dir}/../{os.path.basename(image_url)}")
|
||||||
allowed_local_media_path=temp_dir)
|
with pytest.raises(RuntimeError, match="Cannot load local files"):
|
||||||
with pytest.raises(ValueError):
|
await connector.fetch_image_async(
|
||||||
await async_fetch_image(
|
|
||||||
f"file://{temp_dir}/../{os.path.basename(image_url)}")
|
f"file://{temp_dir}/../{os.path.basename(image_url)}")
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError, match="must be a subpath"):
|
||||||
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}",
|
local_connector.fetch_image(
|
||||||
allowed_local_media_path=temp_dir)
|
f"file://{temp_dir}/../{os.path.basename(image_url)}")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(RuntimeError, match="Cannot load local files"):
|
||||||
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}")
|
connector.fetch_image(
|
||||||
|
f"file://{temp_dir}/../{os.path.basename(image_url)}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
|
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
|
||||||
|
|||||||
@ -21,12 +21,10 @@ class AudioAsset:
|
|||||||
name: Literal["winning_call", "mary_had_lamb"]
|
name: Literal["winning_call", "mary_had_lamb"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def audio_and_sample_rate(self) -> tuple[npt.NDArray, int]:
|
def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
|
||||||
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
|
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
|
||||||
s3_prefix=ASSET_DIR)
|
s3_prefix=ASSET_DIR)
|
||||||
y, sr = librosa.load(audio_path, sr=None)
|
return librosa.load(audio_path, sr=None)
|
||||||
assert isinstance(sr, int)
|
|
||||||
return y, sr
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def url(self) -> str:
|
def url(self) -> str:
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from collections import defaultdict, deque
|
|||||||
from functools import lru_cache, partial
|
from functools import lru_cache, partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
|
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
|
||||||
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)
|
Literal, Optional, Tuple, TypeVar, Union, cast)
|
||||||
|
|
||||||
import jinja2.nodes
|
import jinja2.nodes
|
||||||
import transformers.utils.chat_template_utils as hf_chat_utils
|
import transformers.utils.chat_template_utils as hf_chat_utils
|
||||||
@ -23,6 +23,8 @@ from openai.types.chat import (
|
|||||||
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
|
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
|
||||||
from openai.types.chat import (ChatCompletionMessageToolCallParam,
|
from openai.types.chat import (ChatCompletionMessageToolCallParam,
|
||||||
ChatCompletionToolMessageParam)
|
ChatCompletionToolMessageParam)
|
||||||
|
from openai.types.chat.chat_completion_content_part_input_audio_param import (
|
||||||
|
InputAudio)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
# pydantic needs the TypedDict from typing_extensions
|
# pydantic needs the TypedDict from typing_extensions
|
||||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||||
@ -31,11 +33,7 @@ from typing_extensions import Required, TypeAlias, TypedDict
|
|||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.multimodal import MultiModalDataDict
|
from vllm.multimodal import MultiModalDataDict
|
||||||
from vllm.multimodal.utils import (async_get_and_parse_audio,
|
from vllm.multimodal.utils import MediaConnector
|
||||||
async_get_and_parse_image,
|
|
||||||
async_get_and_parse_video,
|
|
||||||
get_and_parse_audio, get_and_parse_image,
|
|
||||||
get_and_parse_video)
|
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
from vllm.utils import print_warning_once
|
from vllm.utils import print_warning_once
|
||||||
|
|
||||||
@ -368,14 +366,17 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
|||||||
self._tokenizer = tokenizer
|
self._tokenizer = tokenizer
|
||||||
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
|
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
|
||||||
if model_config.multimodal_config else {})
|
if model_config.multimodal_config else {})
|
||||||
self._consumed_items = {k: 0 for k in self._allowed_items}
|
|
||||||
|
|
||||||
self._items: List[_T] = []
|
self._items_by_modality = defaultdict[str, list[_T]](list)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model_config(self) -> ModelConfig:
|
def model_config(self) -> ModelConfig:
|
||||||
return self._model_config
|
return self._model_config
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allowed_local_media_path(self):
|
||||||
|
return self._model_config.allowed_local_media_path
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
|
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str:
|
||||||
@ -435,38 +436,19 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
|||||||
else:
|
else:
|
||||||
raise TypeError(f"Unknown modality: {modality}")
|
raise TypeError(f"Unknown modality: {modality}")
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict:
|
|
||||||
mm_lists: Mapping[str, List[object]] = defaultdict(list)
|
|
||||||
|
|
||||||
# Merge all the multi-modal items
|
|
||||||
for single_mm_data in items:
|
|
||||||
for mm_key, mm_item in single_mm_data.items():
|
|
||||||
if isinstance(mm_item, list):
|
|
||||||
mm_lists[mm_key].extend(mm_item)
|
|
||||||
else:
|
|
||||||
mm_lists[mm_key].append(mm_item)
|
|
||||||
|
|
||||||
# Unpack any single item lists for models that don't expect multiple.
|
|
||||||
return {
|
|
||||||
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
|
|
||||||
for mm_key, mm_list in mm_lists.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
|
def add(self, modality: ModalityStr, item: _T) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Add a multi-modal item to the current prompt and returns the
|
Add a multi-modal item to the current prompt and returns the
|
||||||
placeholder string to use, if any.
|
placeholder string to use, if any.
|
||||||
"""
|
"""
|
||||||
allowed_count = self._allowed_items.get(modality, 1)
|
allowed_count = self._allowed_items.get(modality, 1)
|
||||||
current_count = self._consumed_items.get(modality, 0) + 1
|
current_count = len(self._items_by_modality[modality]) + 1
|
||||||
if current_count > allowed_count:
|
if current_count > allowed_count:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"At most {allowed_count} {modality}(s) may be provided in "
|
f"At most {allowed_count} {modality}(s) may be provided in "
|
||||||
"one request.")
|
"one request.")
|
||||||
|
|
||||||
self._consumed_items[modality] = current_count
|
self._items_by_modality[modality].append(item)
|
||||||
self._items.append(item)
|
|
||||||
|
|
||||||
return self._placeholder_str(modality, current_count)
|
return self._placeholder_str(modality, current_count)
|
||||||
|
|
||||||
@ -475,22 +457,26 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]):
|
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
|
||||||
|
|
||||||
def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
||||||
return self._combine(self._items) if self._items else None
|
if self._items_by_modality:
|
||||||
|
return dict(self._items_by_modality)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
def create_parser(self) -> "BaseMultiModalContentParser":
|
def create_parser(self) -> "BaseMultiModalContentParser":
|
||||||
return MultiModalContentParser(self)
|
return MultiModalContentParser(self)
|
||||||
|
|
||||||
|
|
||||||
class AsyncMultiModalItemTracker(
|
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
|
||||||
BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]):
|
|
||||||
|
|
||||||
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
|
||||||
if self._items:
|
if self._items_by_modality:
|
||||||
items = await asyncio.gather(*self._items)
|
return {
|
||||||
return self._combine(items)
|
modality: await asyncio.gather(*items)
|
||||||
|
for modality, items in self._items_by_modality.items()
|
||||||
|
}
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -522,7 +508,7 @@ class BaseMultiModalContentParser(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
|
def parse_input_audio(self, input_audio: InputAudio) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -537,31 +523,31 @@ class MultiModalContentParser(BaseMultiModalContentParser):
|
|||||||
|
|
||||||
self._tracker = tracker
|
self._tracker = tracker
|
||||||
|
|
||||||
|
self._connector = MediaConnector(
|
||||||
|
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||||
|
)
|
||||||
|
|
||||||
def parse_image(self, image_url: str) -> None:
|
def parse_image(self, image_url: str) -> None:
|
||||||
image = get_and_parse_image(image_url,
|
image = self._connector.fetch_image(image_url)
|
||||||
allowed_local_media_path=self._tracker.
|
|
||||||
_model_config.allowed_local_media_path)
|
|
||||||
|
|
||||||
placeholder = self._tracker.add("image", image)
|
placeholder = self._tracker.add("image", image)
|
||||||
self._add_placeholder(placeholder)
|
self._add_placeholder(placeholder)
|
||||||
|
|
||||||
def parse_audio(self, audio_url: str) -> None:
|
def parse_audio(self, audio_url: str) -> None:
|
||||||
audio = get_and_parse_audio(audio_url)
|
audio = self._connector.fetch_audio(audio_url)
|
||||||
|
|
||||||
placeholder = self._tracker.add("audio", audio)
|
placeholder = self._tracker.add("audio", audio)
|
||||||
self._add_placeholder(placeholder)
|
self._add_placeholder(placeholder)
|
||||||
|
|
||||||
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
|
def parse_input_audio(self, input_audio: InputAudio) -> None:
|
||||||
input_audio_data = input_audio.get("data","")
|
audio_data = input_audio.get("data", "")
|
||||||
input_audio_format = input_audio.get("format","")
|
audio_format = input_audio.get("format", "")
|
||||||
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
|
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
|
||||||
audio = get_and_parse_audio(audio_url)
|
|
||||||
|
|
||||||
placeholder = self._tracker.add("audio", audio)
|
return self.parse_audio(audio_url)
|
||||||
self._add_placeholder(placeholder)
|
|
||||||
|
|
||||||
def parse_video(self, video_url: str) -> None:
|
def parse_video(self, video_url: str) -> None:
|
||||||
video = get_and_parse_video(video_url)
|
video = self._connector.fetch_video(video_url)
|
||||||
|
|
||||||
placeholder = self._tracker.add("video", video)
|
placeholder = self._tracker.add("video", video)
|
||||||
self._add_placeholder(placeholder)
|
self._add_placeholder(placeholder)
|
||||||
@ -573,33 +559,31 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self._tracker = tracker
|
self._tracker = tracker
|
||||||
|
self._connector = MediaConnector(
|
||||||
|
allowed_local_media_path=tracker.allowed_local_media_path,
|
||||||
|
)
|
||||||
|
|
||||||
def parse_image(self, image_url: str) -> None:
|
def parse_image(self, image_url: str) -> None:
|
||||||
image_coro = async_get_and_parse_image(
|
image_coro = self._connector.fetch_image_async(image_url)
|
||||||
image_url,
|
|
||||||
allowed_local_media_path=self._tracker._model_config.
|
|
||||||
allowed_local_media_path)
|
|
||||||
|
|
||||||
placeholder = self._tracker.add("image", image_coro)
|
placeholder = self._tracker.add("image", image_coro)
|
||||||
self._add_placeholder(placeholder)
|
self._add_placeholder(placeholder)
|
||||||
|
|
||||||
def parse_audio(self, audio_url: str) -> None:
|
def parse_audio(self, audio_url: str) -> None:
|
||||||
audio_coro = async_get_and_parse_audio(audio_url)
|
audio_coro = self._connector.fetch_audio_async(audio_url)
|
||||||
|
|
||||||
placeholder = self._tracker.add("audio", audio_coro)
|
placeholder = self._tracker.add("audio", audio_coro)
|
||||||
self._add_placeholder(placeholder)
|
self._add_placeholder(placeholder)
|
||||||
|
|
||||||
def parse_input_audio(self, input_audio: Dict[str, str]) -> None:
|
def parse_input_audio(self, input_audio: InputAudio) -> None:
|
||||||
input_audio_data = input_audio.get("data","")
|
audio_data = input_audio.get("data", "")
|
||||||
input_audio_format = input_audio.get("format","")
|
audio_format = input_audio.get("format", "")
|
||||||
audio_url = f"data:audio/{input_audio_format};base64,{input_audio_data}"
|
audio_url = f"data:audio/{audio_format};base64,{audio_data}"
|
||||||
audio_coro = async_get_and_parse_audio(audio_url)
|
|
||||||
|
|
||||||
placeholder = self._tracker.add("audio", audio_coro)
|
return self.parse_audio(audio_url)
|
||||||
self._add_placeholder(placeholder)
|
|
||||||
|
|
||||||
def parse_video(self, video_url: str) -> None:
|
def parse_video(self, video_url: str) -> None:
|
||||||
video = async_get_and_parse_video(video_url)
|
video = self._connector.fetch_video_async(video_url)
|
||||||
|
|
||||||
placeholder = self._tracker.add("video", video)
|
placeholder = self._tracker.add("video", video)
|
||||||
self._add_placeholder(placeholder)
|
self._add_placeholder(placeholder)
|
||||||
@ -695,10 +679,13 @@ _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
|
|||||||
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
|
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
|
||||||
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
|
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
|
||||||
|
|
||||||
|
_ContentPart: TypeAlias = Union[str, Dict[str, str], InputAudio]
|
||||||
|
|
||||||
# Define a mapping from part types to their corresponding parsing functions.
|
# Define a mapping from part types to their corresponding parsing functions.
|
||||||
MM_PARSER_MAP: Dict[str,
|
MM_PARSER_MAP: Dict[
|
||||||
Callable[[ChatCompletionContentPartParam],
|
str,
|
||||||
Union[str, Dict[str,str]]]] = {
|
Callable[[ChatCompletionContentPartParam], _ContentPart],
|
||||||
|
] = {
|
||||||
"text":
|
"text":
|
||||||
lambda part: _TextParser(part).get("text", ""),
|
lambda part: _TextParser(part).get("text", ""),
|
||||||
"image_url":
|
"image_url":
|
||||||
@ -715,8 +702,7 @@ MM_PARSER_MAP: Dict[str,
|
|||||||
|
|
||||||
|
|
||||||
def _parse_chat_message_content_mm_part(
|
def _parse_chat_message_content_mm_part(
|
||||||
part: ChatCompletionContentPartParam) -> Tuple[str,
|
part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]:
|
||||||
Union[str, Dict[str, str]]]:
|
|
||||||
"""
|
"""
|
||||||
Parses a given multi-modal content part based on its type.
|
Parses a given multi-modal content part based on its type.
|
||||||
|
|
||||||
@ -783,7 +769,7 @@ def _parse_chat_message_content_parts(
|
|||||||
*,
|
*,
|
||||||
wrap_dicts: bool,
|
wrap_dicts: bool,
|
||||||
) -> List[ConversationMessage]:
|
) -> List[ConversationMessage]:
|
||||||
content: List[Union[str, Dict[str, str]]] = []
|
content = list[_ContentPart]()
|
||||||
|
|
||||||
mm_parser = mm_tracker.create_parser()
|
mm_parser = mm_tracker.create_parser()
|
||||||
|
|
||||||
@ -814,7 +800,7 @@ def _parse_chat_message_content_part(
|
|||||||
mm_parser: BaseMultiModalContentParser,
|
mm_parser: BaseMultiModalContentParser,
|
||||||
*,
|
*,
|
||||||
wrap_dicts: bool,
|
wrap_dicts: bool,
|
||||||
) -> Optional[Union[str, Dict[str, str]]]:
|
) -> Optional[_ContentPart]:
|
||||||
"""Parses a single part of a conversation. If wrap_dicts is True,
|
"""Parses a single part of a conversation. If wrap_dicts is True,
|
||||||
structured dictionary pieces for texts and images will be
|
structured dictionary pieces for texts and images will be
|
||||||
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
|
wrapped in dictionaries, i.e., {"type": "text", "text", ...} and
|
||||||
@ -823,8 +809,7 @@ def _parse_chat_message_content_part(
|
|||||||
with multimodal placeholders.
|
with multimodal placeholders.
|
||||||
"""
|
"""
|
||||||
if isinstance(part, str): # Handle plain text parts
|
if isinstance(part, str): # Handle plain text parts
|
||||||
text = _TextParser(part)
|
return part
|
||||||
return text
|
|
||||||
|
|
||||||
# Handle structured dictionary parts
|
# Handle structured dictionary parts
|
||||||
part_type, content = _parse_chat_message_content_mm_part(part)
|
part_type, content = _parse_chat_message_content_mm_part(part)
|
||||||
@ -855,7 +840,7 @@ def _parse_chat_message_content_part(
|
|||||||
return {'type': 'audio'} if wrap_dicts else None
|
return {'type': 'audio'} if wrap_dicts else None
|
||||||
|
|
||||||
if part_type == "input_audio":
|
if part_type == "input_audio":
|
||||||
dict_content = cast(Dict[str, str], content)
|
dict_content = cast(InputAudio, content)
|
||||||
mm_parser.parse_input_audio(dict_content)
|
mm_parser.parse_input_audio(dict_content)
|
||||||
return {'type': 'audio'} if wrap_dicts else None
|
return {'type': 'audio'} if wrap_dicts else None
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,14 @@
|
|||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
|
|
||||||
from vllm.inputs.registry import InputContext
|
from vllm.inputs.registry import InputContext
|
||||||
from vllm.utils import PlaceholderModule
|
from vllm.utils import PlaceholderModule
|
||||||
|
|
||||||
from .base import MultiModalPlugin
|
from .base import MediaIO, MultiModalPlugin
|
||||||
from .inputs import AudioItem, MultiModalData, MultiModalKwargs
|
from .inputs import AudioItem, MultiModalData, MultiModalKwargs
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -12,6 +16,11 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
||||||
|
|
||||||
|
try:
|
||||||
|
import soundfile
|
||||||
|
except ImportError:
|
||||||
|
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
class AudioPlugin(MultiModalPlugin):
|
class AudioPlugin(MultiModalPlugin):
|
||||||
"""Plugin for audio data."""
|
"""Plugin for audio data."""
|
||||||
@ -39,3 +48,28 @@ def resample_audio(
|
|||||||
target_sr: float,
|
target_sr: float,
|
||||||
) -> npt.NDArray[np.floating]:
|
) -> npt.NDArray[np.floating]:
|
||||||
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
|
return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
|
||||||
|
|
||||||
|
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
|
||||||
|
return librosa.load(BytesIO(data), sr=None)
|
||||||
|
|
||||||
|
def load_base64(
|
||||||
|
self,
|
||||||
|
media_type: str,
|
||||||
|
data: str,
|
||||||
|
) -> tuple[npt.NDArray, float]:
|
||||||
|
return self.load_bytes(base64.b64decode(data))
|
||||||
|
|
||||||
|
def load_file(self, filepath: Path) -> tuple[npt.NDArray, float]:
|
||||||
|
return librosa.load(filepath, sr=None)
|
||||||
|
|
||||||
|
def encode_base64(self, media: tuple[npt.NDArray, float]) -> str:
|
||||||
|
audio, sr = media
|
||||||
|
|
||||||
|
with BytesIO() as buffer:
|
||||||
|
soundfile.write(buffer, audio, sr, format="WAV")
|
||||||
|
data = buffer.getvalue()
|
||||||
|
|
||||||
|
return base64.b64encode(data).decode('utf-8')
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
|
from pathlib import Path
|
||||||
|
from typing import (TYPE_CHECKING, Any, Callable, Generic, NamedTuple,
|
||||||
Optional, Sequence, Tuple, Type, TypeVar, Union)
|
Optional, Sequence, Tuple, Type, TypeVar, Union)
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -118,7 +119,7 @@ class MultiModalPlugin(ABC):
|
|||||||
self,
|
self,
|
||||||
model_config: "ModelConfig",
|
model_config: "ModelConfig",
|
||||||
data: MultiModalData[Any],
|
data: MultiModalData[Any],
|
||||||
mm_processor_kwargs: Optional[Dict[str, Any]],
|
mm_processor_kwargs: Optional[dict[str, Any]],
|
||||||
) -> MultiModalKwargs:
|
) -> MultiModalKwargs:
|
||||||
"""
|
"""
|
||||||
Transform the data into a dictionary of model inputs using the
|
Transform the data into a dictionary of model inputs using the
|
||||||
@ -254,10 +255,10 @@ class MultiModalPlaceholderMap:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
class IndexMap(NamedTuple):
|
class IndexMap(NamedTuple):
|
||||||
src: List[int]
|
src: list[int]
|
||||||
dest: List[int]
|
dest: list[int]
|
||||||
|
|
||||||
src_ranges: List[range]
|
src_ranges: list[range]
|
||||||
"""
|
"""
|
||||||
The indices of the multi-modal embeddings that will replace the
|
The indices of the multi-modal embeddings that will replace the
|
||||||
corresponding placeholder embeddings pointed to by ``dest_ranges``.
|
corresponding placeholder embeddings pointed to by ``dest_ranges``.
|
||||||
@ -268,7 +269,7 @@ class MultiModalPlaceholderMap:
|
|||||||
The total number of flattened multi-modal embeddings.
|
The total number of flattened multi-modal embeddings.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dest_ranges: List[range]
|
dest_ranges: list[range]
|
||||||
"""
|
"""
|
||||||
The indices of the placeholder embeddings that will be replaced by the
|
The indices of the placeholder embeddings that will be replaced by the
|
||||||
multimodal embeddings.
|
multimodal embeddings.
|
||||||
@ -288,7 +289,7 @@ class MultiModalPlaceholderMap:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_seq_group(
|
def from_seq_group(
|
||||||
cls, seq_group: "SequenceGroupMetadata", positions: range
|
cls, seq_group: "SequenceGroupMetadata", positions: range
|
||||||
) -> Tuple[Optional[MultiModalDataDict], Dict[str,
|
) -> Tuple[Optional[MultiModalDataDict], dict[str,
|
||||||
"MultiModalPlaceholderMap"]]:
|
"MultiModalPlaceholderMap"]]:
|
||||||
"""
|
"""
|
||||||
Returns the multi-modal items that intersect with the portion of a
|
Returns the multi-modal items that intersect with the portion of a
|
||||||
@ -376,9 +377,9 @@ class MultiModalPlaceholderMap:
|
|||||||
def append_items_from_seq_group(
|
def append_items_from_seq_group(
|
||||||
self,
|
self,
|
||||||
positions: range,
|
positions: range,
|
||||||
multi_modal_items: List[_T],
|
multi_modal_items: list[_T],
|
||||||
multi_modal_placeholders: Sequence[PlaceholderRange],
|
multi_modal_placeholders: Sequence[PlaceholderRange],
|
||||||
) -> List[_T]:
|
) -> list[_T]:
|
||||||
"""
|
"""
|
||||||
Adds the multi-modal items that intersect ```positions`` to this
|
Adds the multi-modal items that intersect ```positions`` to this
|
||||||
placeholder map and returns the intersecting items.
|
placeholder map and returns the intersecting items.
|
||||||
@ -454,3 +455,22 @@ class MultiModalPlaceholderMap:
|
|||||||
|
|
||||||
return MultiModalPlaceholderMap.IndexMap(src=src_indices,
|
return MultiModalPlaceholderMap.IndexMap(src=src_indices,
|
||||||
dest=dest_indices)
|
dest=dest_indices)
|
||||||
|
|
||||||
|
|
||||||
|
class MediaIO(ABC, Generic[_T]):
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_bytes(self, data: bytes) -> _T:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_base64(self, media_type: str, data: str) -> _T:
|
||||||
|
"""
|
||||||
|
List of media types:
|
||||||
|
https://www.iana.org/assignments/media-types/media-types.xhtml
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load_file(self, filepath: Path) -> _T:
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
@ -1,4 +1,7 @@
|
|||||||
|
import base64
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -9,7 +12,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.transformers_utils.processor import get_image_processor
|
from vllm.transformers_utils.processor import get_image_processor
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils import is_list_of
|
||||||
|
|
||||||
from .base import MultiModalPlugin
|
from .base import MediaIO, MultiModalPlugin
|
||||||
from .inputs import ImageItem, MultiModalData, MultiModalKwargs
|
from .inputs import ImageItem, MultiModalData, MultiModalKwargs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -96,3 +99,39 @@ def rescale_image_size(image: Image.Image,
|
|||||||
if transpose >= 0:
|
if transpose >= 0:
|
||||||
image = image.transpose(Image.Transpose(transpose))
|
image = image.transpose(Image.Transpose(transpose))
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class ImageMediaIO(MediaIO[Image.Image]):
|
||||||
|
|
||||||
|
def __init__(self, *, image_mode: str = "RGB") -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.image_mode = image_mode
|
||||||
|
|
||||||
|
def load_bytes(self, data: bytes) -> Image.Image:
|
||||||
|
image = Image.open(BytesIO(data))
|
||||||
|
image.load()
|
||||||
|
return image.convert(self.image_mode)
|
||||||
|
|
||||||
|
def load_base64(self, media_type: str, data: str) -> Image.Image:
|
||||||
|
return self.load_bytes(base64.b64decode(data))
|
||||||
|
|
||||||
|
def load_file(self, filepath: Path) -> Image.Image:
|
||||||
|
image = Image.open(filepath)
|
||||||
|
image.load()
|
||||||
|
return image.convert(self.image_mode)
|
||||||
|
|
||||||
|
def encode_base64(
|
||||||
|
self,
|
||||||
|
media: Image.Image,
|
||||||
|
*,
|
||||||
|
image_format: str = "JPEG",
|
||||||
|
) -> str:
|
||||||
|
image = media
|
||||||
|
|
||||||
|
with BytesIO() as buffer:
|
||||||
|
image = image.convert(self.image_mode)
|
||||||
|
image.save(buffer, image_format)
|
||||||
|
data = buffer.getvalue()
|
||||||
|
|
||||||
|
return base64.b64encode(data).decode('utf-8')
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
import base64
|
|
||||||
import os
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from io import BytesIO
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, TypeVar, Union
|
from typing import Optional, TypeVar, Union
|
||||||
|
from urllib.parse import ParseResult, urlparse
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
@ -10,283 +9,246 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.connections import global_http_connection
|
from vllm.connections import HTTPConnection, global_http_connection
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||||
from vllm.utils import PlaceholderModule
|
|
||||||
|
|
||||||
from .inputs import MultiModalDataDict, PlaceholderRange
|
from .audio import AudioMediaIO
|
||||||
|
from .base import MediaIO
|
||||||
try:
|
from .image import ImageMediaIO
|
||||||
import decord
|
from .inputs import PlaceholderRange
|
||||||
except ImportError:
|
from .video import VideoMediaIO
|
||||||
decord = PlaceholderModule("decord") # type: ignore[assignment]
|
|
||||||
|
|
||||||
try:
|
|
||||||
import librosa
|
|
||||||
except ImportError:
|
|
||||||
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
|
|
||||||
|
|
||||||
try:
|
|
||||||
import soundfile
|
|
||||||
except ImportError:
|
|
||||||
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
cached_get_tokenizer = lru_cache(get_tokenizer)
|
cached_get_tokenizer = lru_cache(get_tokenizer)
|
||||||
|
|
||||||
|
_M = TypeVar("_M")
|
||||||
def _load_image_from_bytes(b: bytes) -> Image.Image:
|
|
||||||
image = Image.open(BytesIO(b))
|
|
||||||
image.load()
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
|
||||||
def _is_subpath(image_path: str, allowed_local_media_path: str) -> bool:
|
class MediaConnector:
|
||||||
# Get the common path
|
|
||||||
common_path = os.path.commonpath([
|
|
||||||
os.path.abspath(image_path),
|
|
||||||
os.path.abspath(allowed_local_media_path)
|
|
||||||
])
|
|
||||||
# Check if the common path is the same as allowed_local_media_path
|
|
||||||
return common_path == os.path.abspath(allowed_local_media_path)
|
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
connection: HTTPConnection = global_http_connection,
|
||||||
|
*,
|
||||||
|
allowed_local_media_path: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def _load_image_from_file(image_url: str,
|
self.connection = connection
|
||||||
allowed_local_media_path: str) -> Image.Image:
|
|
||||||
if not allowed_local_media_path:
|
if allowed_local_media_path:
|
||||||
raise ValueError("Invalid 'image_url': Cannot load local files without"
|
allowed_local_media_path_ = Path(allowed_local_media_path)
|
||||||
"'--allowed-local-media-path'.")
|
|
||||||
if allowed_local_media_path:
|
if not allowed_local_media_path_.exists():
|
||||||
if not os.path.exists(allowed_local_media_path):
|
raise ValueError(
|
||||||
|
"Invalid `--allowed-local-media-path`: The path "
|
||||||
|
f"{allowed_local_media_path_} does not exist.")
|
||||||
|
if not allowed_local_media_path_.is_dir():
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid `--allowed-local-media-path`: The path "
|
||||||
|
f"{allowed_local_media_path_} must be a directory.")
|
||||||
|
else:
|
||||||
|
allowed_local_media_path_ = None
|
||||||
|
|
||||||
|
self.allowed_local_media_path = allowed_local_media_path_
|
||||||
|
|
||||||
|
def _load_data_url(
|
||||||
|
self,
|
||||||
|
url_spec: ParseResult,
|
||||||
|
media_io: MediaIO[_M],
|
||||||
|
) -> _M:
|
||||||
|
data_spec, data = url_spec.path.split(",", 1)
|
||||||
|
media_type, data_type = data_spec.split(";", 1)
|
||||||
|
|
||||||
|
if data_type != "base64":
|
||||||
|
msg = "Only base64 data URLs are supported for now."
|
||||||
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
return media_io.load_base64(media_type, data)
|
||||||
|
|
||||||
|
def _load_file_url(
|
||||||
|
self,
|
||||||
|
url_spec: ParseResult,
|
||||||
|
media_io: MediaIO[_M],
|
||||||
|
) -> _M:
|
||||||
|
allowed_local_media_path = self.allowed_local_media_path
|
||||||
|
if allowed_local_media_path is None:
|
||||||
|
raise RuntimeError("Cannot load local files without "
|
||||||
|
"`--allowed-local-media-path`.")
|
||||||
|
|
||||||
|
filepath = Path(url_spec.path)
|
||||||
|
if allowed_local_media_path not in filepath.resolve().parents:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid '--allowed-local-media-path': "
|
f"The file path {filepath} must be a subpath "
|
||||||
f"The path {allowed_local_media_path} does not exist.")
|
f"of `--allowed-local-media-path` {allowed_local_media_path}.")
|
||||||
if not os.path.isdir(allowed_local_media_path):
|
|
||||||
raise ValueError(
|
|
||||||
"Invalid '--allowed-local-media-path': "
|
|
||||||
f"The path {allowed_local_media_path} must be a directory.")
|
|
||||||
|
|
||||||
# Only split once and assume the second part is the image path
|
return media_io.load_file(filepath)
|
||||||
_, image_path = image_url.split("file://", 1)
|
|
||||||
if not _is_subpath(image_path, allowed_local_media_path):
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid 'image_url': The file path {image_path} must"
|
|
||||||
" be a subpath of '--allowed-local-media-path'"
|
|
||||||
f" '{allowed_local_media_path}'.")
|
|
||||||
|
|
||||||
image = Image.open(image_path)
|
def load_from_url(
|
||||||
image.load()
|
self,
|
||||||
return image
|
url: str,
|
||||||
|
media_io: MediaIO[_M],
|
||||||
|
*,
|
||||||
|
fetch_timeout: Optional[int] = None,
|
||||||
|
) -> _M:
|
||||||
|
url_spec = urlparse(url)
|
||||||
|
|
||||||
|
if url_spec.scheme.startswith("http"):
|
||||||
|
connection = self.connection
|
||||||
|
data = connection.get_bytes(url, timeout=fetch_timeout)
|
||||||
|
|
||||||
def _load_image_from_data_url(image_url: str) -> Image.Image:
|
return media_io.load_bytes(data)
|
||||||
# Only split once and assume the second part is the base64 encoded image
|
|
||||||
_, image_base64 = image_url.split(",", 1)
|
|
||||||
return load_image_from_base64(image_base64)
|
|
||||||
|
|
||||||
|
if url_spec.scheme == "data":
|
||||||
|
return self._load_data_url(url_spec, media_io)
|
||||||
|
|
||||||
def fetch_image(image_url: str,
|
if url_spec.scheme == "file":
|
||||||
*,
|
return self._load_file_url(url_spec, media_io)
|
||||||
image_mode: str = "RGB",
|
|
||||||
allowed_local_media_path: str = "") -> Image.Image:
|
|
||||||
"""
|
|
||||||
Load a PIL image from a HTTP or base64 data URL.
|
|
||||||
|
|
||||||
By default, the image is converted into RGB format.
|
msg = "The URL must be either a HTTP, data or file URL."
|
||||||
"""
|
raise ValueError(msg)
|
||||||
if image_url.startswith('http'):
|
|
||||||
image_raw = global_http_connection.get_bytes(
|
|
||||||
image_url,
|
|
||||||
timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
|
|
||||||
)
|
|
||||||
image = _load_image_from_bytes(image_raw)
|
|
||||||
|
|
||||||
elif image_url.startswith('data:image'):
|
async def load_from_url_async(
|
||||||
image = _load_image_from_data_url(image_url)
|
self,
|
||||||
elif image_url.startswith('file://'):
|
url: str,
|
||||||
image = _load_image_from_file(image_url, allowed_local_media_path)
|
media_io: MediaIO[_M],
|
||||||
else:
|
*,
|
||||||
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
|
fetch_timeout: Optional[int] = None,
|
||||||
"with either 'data:image', 'file://' or 'http'.")
|
) -> _M:
|
||||||
|
url_spec = urlparse(url)
|
||||||
|
|
||||||
return image.convert(image_mode)
|
if url_spec.scheme.startswith("http"):
|
||||||
|
connection = self.connection
|
||||||
|
data = await connection.async_get_bytes(url, timeout=fetch_timeout)
|
||||||
|
|
||||||
|
return media_io.load_bytes(data)
|
||||||
|
|
||||||
async def async_fetch_image(image_url: str,
|
if url_spec.scheme == "data":
|
||||||
*,
|
return self._load_data_url(url_spec, media_io)
|
||||||
image_mode: str = "RGB",
|
|
||||||
allowed_local_media_path: str = "") -> Image.Image:
|
|
||||||
"""
|
|
||||||
Asynchronously load a PIL image from a HTTP or base64 data URL.
|
|
||||||
|
|
||||||
By default, the image is converted into RGB format.
|
if url_spec.scheme == "file":
|
||||||
"""
|
return self._load_file_url(url_spec, media_io)
|
||||||
if image_url.startswith('http'):
|
|
||||||
image_raw = await global_http_connection.async_get_bytes(
|
|
||||||
image_url,
|
|
||||||
timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
|
|
||||||
)
|
|
||||||
image = _load_image_from_bytes(image_raw)
|
|
||||||
|
|
||||||
elif image_url.startswith('data:image'):
|
msg = "The URL must be either a HTTP, data or file URL."
|
||||||
image = _load_image_from_data_url(image_url)
|
raise ValueError(msg)
|
||||||
elif image_url.startswith('file://'):
|
|
||||||
image = _load_image_from_file(image_url, allowed_local_media_path)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
|
|
||||||
"with either 'data:image', 'file://' or 'http'.")
|
|
||||||
|
|
||||||
return image.convert(image_mode)
|
def fetch_audio(
|
||||||
|
self,
|
||||||
|
audio_url: str,
|
||||||
|
) -> tuple[np.ndarray, Union[int, float]]:
|
||||||
|
"""
|
||||||
|
Load audio from a URL.
|
||||||
|
"""
|
||||||
|
audio_io = AudioMediaIO()
|
||||||
|
|
||||||
|
return self.load_from_url(
|
||||||
def _load_video_from_bytes(b: bytes, num_frames: int = 32) -> npt.NDArray:
|
|
||||||
video_path = BytesIO(b)
|
|
||||||
vr = decord.VideoReader(video_path, num_threads=1)
|
|
||||||
total_frame_num = len(vr)
|
|
||||||
|
|
||||||
if total_frame_num > num_frames:
|
|
||||||
uniform_sampled_frames = np.linspace(0,
|
|
||||||
total_frame_num - 1,
|
|
||||||
num_frames,
|
|
||||||
dtype=int)
|
|
||||||
frame_idx = uniform_sampled_frames.tolist()
|
|
||||||
else:
|
|
||||||
frame_idx = [i for i in range(0, total_frame_num)]
|
|
||||||
frames = vr.get_batch(frame_idx).asnumpy()
|
|
||||||
|
|
||||||
return frames
|
|
||||||
|
|
||||||
|
|
||||||
def _load_video_from_data_url(video_url: str) -> npt.NDArray:
|
|
||||||
# Only split once and assume the second part is the base64 encoded video
|
|
||||||
_, video_base64 = video_url.split(",", 1)
|
|
||||||
|
|
||||||
if video_url.startswith("data:video/jpeg;"):
|
|
||||||
return np.stack([
|
|
||||||
np.array(load_image_from_base64(frame_base64))
|
|
||||||
for frame_base64 in video_base64.split(",")
|
|
||||||
])
|
|
||||||
|
|
||||||
return load_video_from_base64(video_base64)
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_video(video_url: str, *, num_frames: int = 32) -> npt.NDArray:
|
|
||||||
"""
|
|
||||||
Load video from a HTTP or base64 data URL.
|
|
||||||
"""
|
|
||||||
if video_url.startswith('http') or video_url.startswith('https'):
|
|
||||||
video_raw = global_http_connection.get_bytes(
|
|
||||||
video_url,
|
|
||||||
timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
|
|
||||||
)
|
|
||||||
video = _load_video_from_bytes(video_raw, num_frames)
|
|
||||||
elif video_url.startswith('data:video'):
|
|
||||||
video = _load_video_from_data_url(video_url)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
|
|
||||||
"with either 'data:video' or 'http'.")
|
|
||||||
return video
|
|
||||||
|
|
||||||
|
|
||||||
async def async_fetch_video(video_url: str,
|
|
||||||
*,
|
|
||||||
num_frames: int = 32) -> npt.NDArray:
|
|
||||||
"""
|
|
||||||
Asynchronously load video from a HTTP or base64 data URL.
|
|
||||||
|
|
||||||
By default, the image is converted into RGB format.
|
|
||||||
"""
|
|
||||||
if video_url.startswith('http') or video_url.startswith('https'):
|
|
||||||
video_raw = await global_http_connection.async_get_bytes(
|
|
||||||
video_url,
|
|
||||||
timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
|
|
||||||
)
|
|
||||||
video = _load_video_from_bytes(video_raw, num_frames)
|
|
||||||
elif video_url.startswith('data:video'):
|
|
||||||
video = _load_video_from_data_url(video_url)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid 'video_url': A valid 'video_url' must start "
|
|
||||||
"with either 'data:video' or 'http'.")
|
|
||||||
return video
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
|
|
||||||
"""
|
|
||||||
Load audio from a URL.
|
|
||||||
"""
|
|
||||||
if audio_url.startswith("http"):
|
|
||||||
audio_bytes = global_http_connection.get_bytes(
|
|
||||||
audio_url,
|
audio_url,
|
||||||
timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
|
audio_io,
|
||||||
|
fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
|
||||||
)
|
)
|
||||||
elif audio_url.startswith("data:audio"):
|
|
||||||
_, audio_base64 = audio_url.split(",", 1)
|
|
||||||
audio_bytes = base64.b64decode(audio_base64)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
|
|
||||||
"with either 'data:audio' or 'http'.")
|
|
||||||
|
|
||||||
return librosa.load(BytesIO(audio_bytes), sr=None)
|
async def fetch_audio_async(
|
||||||
|
self,
|
||||||
|
audio_url: str,
|
||||||
|
) -> tuple[np.ndarray, Union[int, float]]:
|
||||||
|
"""
|
||||||
|
Asynchronously fetch audio from a URL.
|
||||||
|
"""
|
||||||
|
audio_io = AudioMediaIO()
|
||||||
|
|
||||||
|
return await self.load_from_url_async(
|
||||||
async def async_fetch_audio(
|
|
||||||
audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
|
|
||||||
"""
|
|
||||||
Asynchronously fetch audio from a URL.
|
|
||||||
"""
|
|
||||||
if audio_url.startswith("http"):
|
|
||||||
audio_bytes = await global_http_connection.async_get_bytes(
|
|
||||||
audio_url,
|
audio_url,
|
||||||
timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
|
audio_io,
|
||||||
|
fetch_timeout=envs.VLLM_AUDIO_FETCH_TIMEOUT,
|
||||||
)
|
)
|
||||||
elif audio_url.startswith("data:audio"):
|
|
||||||
_, audio_base64 = audio_url.split(",", 1)
|
|
||||||
audio_bytes = base64.b64decode(audio_base64)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
|
|
||||||
"with either 'data:audio' or 'http'.")
|
|
||||||
|
|
||||||
return librosa.load(BytesIO(audio_bytes), sr=None)
|
def fetch_image(
|
||||||
|
self,
|
||||||
|
|
||||||
def get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
|
|
||||||
audio, sr = fetch_audio(audio_url)
|
|
||||||
return {"audio": (audio, sr)}
|
|
||||||
|
|
||||||
|
|
||||||
def get_and_parse_image(
|
|
||||||
image_url: str,
|
image_url: str,
|
||||||
*,
|
*,
|
||||||
allowed_local_media_path: str = "") -> MultiModalDataDict:
|
image_mode: str = "RGB",
|
||||||
image = fetch_image(image_url,
|
) -> Image.Image:
|
||||||
allowed_local_media_path=allowed_local_media_path)
|
"""
|
||||||
return {"image": image}
|
Load a PIL image from a HTTP or base64 data URL.
|
||||||
|
|
||||||
|
By default, the image is converted into RGB format.
|
||||||
|
"""
|
||||||
|
image_io = ImageMediaIO(image_mode=image_mode)
|
||||||
|
|
||||||
def get_and_parse_video(video_url: str) -> MultiModalDataDict:
|
return self.load_from_url(
|
||||||
video = fetch_video(video_url)
|
image_url,
|
||||||
return {"video": video}
|
image_io,
|
||||||
|
fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def fetch_image_async(
|
||||||
async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
|
self,
|
||||||
audio, sr = await async_fetch_audio(audio_url)
|
|
||||||
return {"audio": (audio, sr)}
|
|
||||||
|
|
||||||
|
|
||||||
async def async_get_and_parse_image(
|
|
||||||
image_url: str,
|
image_url: str,
|
||||||
*,
|
*,
|
||||||
allowed_local_media_path: str = "") -> MultiModalDataDict:
|
image_mode: str = "RGB",
|
||||||
image = await async_fetch_image(
|
) -> Image.Image:
|
||||||
image_url, allowed_local_media_path=allowed_local_media_path)
|
"""
|
||||||
return {"image": image}
|
Asynchronously load a PIL image from a HTTP or base64 data URL.
|
||||||
|
|
||||||
|
By default, the image is converted into RGB format.
|
||||||
|
"""
|
||||||
|
image_io = ImageMediaIO(image_mode=image_mode)
|
||||||
|
|
||||||
|
return await self.load_from_url_async(
|
||||||
|
image_url,
|
||||||
|
image_io,
|
||||||
|
fetch_timeout=envs.VLLM_IMAGE_FETCH_TIMEOUT,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fetch_video(
|
||||||
|
self,
|
||||||
|
video_url: str,
|
||||||
|
*,
|
||||||
|
image_mode: str = "RGB",
|
||||||
|
num_frames: int = 32,
|
||||||
|
) -> npt.NDArray:
|
||||||
|
"""
|
||||||
|
Load video from a HTTP or base64 data URL.
|
||||||
|
"""
|
||||||
|
image_io = ImageMediaIO(image_mode=image_mode)
|
||||||
|
video_io = VideoMediaIO(image_io, num_frames=num_frames)
|
||||||
|
|
||||||
|
return self.load_from_url(
|
||||||
|
video_url,
|
||||||
|
video_io,
|
||||||
|
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def fetch_video_async(
|
||||||
|
self,
|
||||||
|
video_url: str,
|
||||||
|
*,
|
||||||
|
image_mode: str = "RGB",
|
||||||
|
num_frames: int = 32,
|
||||||
|
) -> npt.NDArray:
|
||||||
|
"""
|
||||||
|
Asynchronously load video from a HTTP or base64 data URL.
|
||||||
|
|
||||||
|
By default, the image is converted into RGB format.
|
||||||
|
"""
|
||||||
|
image_io = ImageMediaIO(image_mode=image_mode)
|
||||||
|
video_io = VideoMediaIO(image_io, num_frames=num_frames)
|
||||||
|
|
||||||
|
return await self.load_from_url_async(
|
||||||
|
video_url,
|
||||||
|
video_io,
|
||||||
|
fetch_timeout=envs.VLLM_VIDEO_FETCH_TIMEOUT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def async_get_and_parse_video(video_url: str) -> MultiModalDataDict:
|
global_media_connector = MediaConnector()
|
||||||
video = await async_fetch_video(video_url)
|
"""The global :class:`MediaConnector` instance used by vLLM."""
|
||||||
return {"video": video}
|
|
||||||
|
fetch_audio = global_media_connector.fetch_audio
|
||||||
|
fetch_image = global_media_connector.fetch_image
|
||||||
|
fetch_video = global_media_connector.fetch_video
|
||||||
|
|
||||||
|
|
||||||
def encode_audio_base64(
|
def encode_audio_base64(
|
||||||
@ -294,10 +256,8 @@ def encode_audio_base64(
|
|||||||
sampling_rate: int,
|
sampling_rate: int,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Encode audio as base64."""
|
"""Encode audio as base64."""
|
||||||
buffered = BytesIO()
|
audio_io = AudioMediaIO()
|
||||||
soundfile.write(buffered, audio, sampling_rate, format="WAV")
|
return audio_io.encode_base64((audio, sampling_rate))
|
||||||
|
|
||||||
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
|
||||||
|
|
||||||
|
|
||||||
def encode_image_base64(
|
def encode_image_base64(
|
||||||
@ -311,29 +271,14 @@ def encode_image_base64(
|
|||||||
|
|
||||||
By default, the image is converted into RGB format before being encoded.
|
By default, the image is converted into RGB format before being encoded.
|
||||||
"""
|
"""
|
||||||
buffered = BytesIO()
|
image_io = ImageMediaIO(image_mode=image_mode)
|
||||||
image = image.convert(image_mode)
|
return image_io.encode_base64(image, image_format=format)
|
||||||
image.save(buffered, format)
|
|
||||||
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
|
||||||
|
|
||||||
|
|
||||||
def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
|
|
||||||
"""Load image from base64 format."""
|
|
||||||
return _load_image_from_bytes(base64.b64decode(image))
|
|
||||||
|
|
||||||
|
|
||||||
def encode_video_base64(frames: npt.NDArray) -> str:
|
def encode_video_base64(frames: npt.NDArray) -> str:
|
||||||
base64_frames = []
|
image_io = ImageMediaIO()
|
||||||
frames_list = [frames[i] for i in range(frames.shape[0])]
|
video_io = VideoMediaIO(image_io)
|
||||||
for frame in frames_list:
|
return video_io.encode_base64(frames)
|
||||||
img_base64 = encode_image_base64(Image.fromarray(frame))
|
|
||||||
base64_frames.append(img_base64)
|
|
||||||
return ",".join(base64_frames)
|
|
||||||
|
|
||||||
|
|
||||||
def load_video_from_base64(video: Union[bytes, str]) -> npt.NDArray:
|
|
||||||
"""Load video from base64 format."""
|
|
||||||
return _load_video_from_bytes(base64.b64decode(video))
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_visual_encoder_outputs(
|
def resolve_visual_encoder_outputs(
|
||||||
@ -389,7 +334,7 @@ def repeat_and_pad_token(
|
|||||||
repeat_count: int = 1,
|
repeat_count: int = 1,
|
||||||
pad_token_left: Optional[_T] = None,
|
pad_token_left: Optional[_T] = None,
|
||||||
pad_token_right: Optional[_T] = None,
|
pad_token_right: Optional[_T] = None,
|
||||||
) -> List[_T]:
|
) -> list[_T]:
|
||||||
replacement = [token] * repeat_count
|
replacement = [token] * repeat_count
|
||||||
if pad_token_left is not None:
|
if pad_token_left is not None:
|
||||||
replacement = [pad_token_left] + replacement
|
replacement = [pad_token_left] + replacement
|
||||||
@ -402,13 +347,13 @@ def repeat_and_pad_token(
|
|||||||
def repeat_and_pad_placeholder_tokens(
|
def repeat_and_pad_placeholder_tokens(
|
||||||
tokenizer: AnyTokenizer,
|
tokenizer: AnyTokenizer,
|
||||||
prompt: Optional[str],
|
prompt: Optional[str],
|
||||||
prompt_token_ids: List[int],
|
prompt_token_ids: list[int],
|
||||||
*,
|
*,
|
||||||
placeholder_token_id: int,
|
placeholder_token_id: int,
|
||||||
repeat_count: Union[int, List[int]],
|
repeat_count: Union[int, list[int]],
|
||||||
pad_token_left: Optional[int] = None,
|
pad_token_left: Optional[int] = None,
|
||||||
pad_token_right: Optional[int] = None,
|
pad_token_right: Optional[int] = None,
|
||||||
) -> Tuple[Optional[str], List[int], List[PlaceholderRange]]:
|
) -> tuple[Optional[str], list[int], list[PlaceholderRange]]:
|
||||||
if isinstance(repeat_count, int):
|
if isinstance(repeat_count, int):
|
||||||
repeat_count = [repeat_count]
|
repeat_count = [repeat_count]
|
||||||
|
|
||||||
@ -450,8 +395,8 @@ def repeat_and_pad_placeholder_tokens(
|
|||||||
new_prompt += prompt_parts[i] + replacement_str
|
new_prompt += prompt_parts[i] + replacement_str
|
||||||
new_prompt += prompt_parts[-1]
|
new_prompt += prompt_parts[-1]
|
||||||
|
|
||||||
new_token_ids: List[int] = []
|
new_token_ids = list[int]()
|
||||||
placeholder_ranges: List[PlaceholderRange] = []
|
placeholder_ranges = list[PlaceholderRange]()
|
||||||
placeholder_token_idx = 0
|
placeholder_token_idx = 0
|
||||||
for i, token in enumerate(prompt_token_ids):
|
for i, token in enumerate(prompt_token_ids):
|
||||||
if token == placeholder_token_id:
|
if token == placeholder_token_id:
|
||||||
@ -481,7 +426,7 @@ def repeat_and_pad_placeholder_tokens(
|
|||||||
def consecutive_placeholder_ranges(
|
def consecutive_placeholder_ranges(
|
||||||
num_items: int,
|
num_items: int,
|
||||||
item_size: int,
|
item_size: int,
|
||||||
initial_offset: int = 0) -> List[PlaceholderRange]:
|
initial_offset: int = 0) -> list[PlaceholderRange]:
|
||||||
"""Returns a list of consecutive PlaceholderRanges of a fixed size"""
|
"""Returns a list of consecutive PlaceholderRanges of a fixed size"""
|
||||||
|
|
||||||
return [
|
return [
|
||||||
|
|||||||
@ -1,23 +1,32 @@
|
|||||||
from functools import lru_cache
|
import base64
|
||||||
|
from functools import lru_cache, partial
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from vllm.inputs.registry import InputContext
|
from vllm.inputs.registry import InputContext
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.processor import get_video_processor
|
from vllm.transformers_utils.processor import get_video_processor
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils import PlaceholderModule, is_list_of
|
||||||
|
|
||||||
from .base import MultiModalData
|
from .base import MediaIO, MultiModalData
|
||||||
from .image import ImagePlugin
|
from .image import ImageMediaIO, ImagePlugin
|
||||||
from .inputs import MultiModalKwargs, VideoItem
|
from .inputs import MultiModalKwargs, VideoItem
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
|
|
||||||
|
try:
|
||||||
|
import decord
|
||||||
|
except ImportError:
|
||||||
|
decord = PlaceholderModule("decord") # type: ignore[assignment]
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
cached_get_video_processor = lru_cache(get_video_processor)
|
cached_get_video_processor = lru_cache(get_video_processor)
|
||||||
@ -107,3 +116,73 @@ def sample_frames_from_video(frames: npt.NDArray,
|
|||||||
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
|
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
|
||||||
sampled_frames = frames[frame_indices, ...]
|
sampled_frames = frames[frame_indices, ...]
|
||||||
return sampled_frames
|
return sampled_frames
|
||||||
|
|
||||||
|
|
||||||
|
class VideoMediaIO(MediaIO[npt.NDArray]):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_io: ImageMediaIO,
|
||||||
|
*,
|
||||||
|
num_frames: int = 32,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.image_io = image_io
|
||||||
|
self.num_frames = num_frames
|
||||||
|
|
||||||
|
def load_bytes(self, data: bytes) -> npt.NDArray:
|
||||||
|
vr = decord.VideoReader(BytesIO(data), num_threads=1)
|
||||||
|
total_frame_num = len(vr)
|
||||||
|
|
||||||
|
num_frames = self.num_frames
|
||||||
|
if total_frame_num > num_frames:
|
||||||
|
uniform_sampled_frames = np.linspace(0,
|
||||||
|
total_frame_num - 1,
|
||||||
|
num_frames,
|
||||||
|
dtype=int)
|
||||||
|
frame_idx = uniform_sampled_frames.tolist()
|
||||||
|
else:
|
||||||
|
frame_idx = list(range(0, total_frame_num))
|
||||||
|
|
||||||
|
return vr.get_batch(frame_idx).asnumpy()
|
||||||
|
|
||||||
|
def load_base64(self, media_type: str, data: str) -> npt.NDArray:
|
||||||
|
if media_type.lower() == "video/jpeg":
|
||||||
|
load_frame = partial(
|
||||||
|
self.image_io.load_base64,
|
||||||
|
"image/jpeg",
|
||||||
|
)
|
||||||
|
|
||||||
|
return np.stack([
|
||||||
|
np.array(load_frame(frame_data))
|
||||||
|
for frame_data in data.split(",")
|
||||||
|
])
|
||||||
|
|
||||||
|
return self.load_bytes(base64.b64decode(data))
|
||||||
|
|
||||||
|
def load_file(self, filepath: Path) -> npt.NDArray:
|
||||||
|
with filepath.open("rb") as f:
|
||||||
|
data = f.read()
|
||||||
|
|
||||||
|
return self.load_bytes(data)
|
||||||
|
|
||||||
|
def encode_base64(
|
||||||
|
self,
|
||||||
|
media: npt.NDArray,
|
||||||
|
*,
|
||||||
|
video_format: str = "JPEG",
|
||||||
|
) -> str:
|
||||||
|
video = media
|
||||||
|
|
||||||
|
if video_format == "JPEG":
|
||||||
|
encode_frame = partial(
|
||||||
|
self.image_io.encode_base64,
|
||||||
|
image_format=video_format,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ",".join(
|
||||||
|
encode_frame(Image.fromarray(frame)) for frame in video)
|
||||||
|
|
||||||
|
msg = "Only JPEG format is supported for now."
|
||||||
|
raise NotImplementedError(msg)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user