[Frontend] Support configurable mm placeholder strings & flexible video sampling policies via CLI flags. (#20105)

Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
This commit is contained in:
Chenheli Hua 2025-07-01 23:34:03 -07:00 committed by GitHub
parent 7da296be04
commit 2e7cbf2d7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 199 additions and 29 deletions

View File

@ -6,8 +6,8 @@ import os
import uuid import uuid
from asyncio import CancelledError from asyncio import CancelledError
from copy import copy from copy import copy
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional from typing import Any, Optional
import pytest import pytest
import pytest_asyncio import pytest_asyncio
@ -32,6 +32,8 @@ class RequestOutput:
@dataclass @dataclass
class MockModelConfig: class MockModelConfig:
use_async_output_proc = True use_async_output_proc = True
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
class MockEngine: class MockEngine:

View File

@ -231,6 +231,58 @@ def test_limit_mm_per_prompt_parser(arg, expected):
assert args.limit_mm_per_prompt == expected assert args.limit_mm_per_prompt == expected
@pytest.mark.parametrize(
    ("arg", "expected"),
    [
        # Flag omitted: the parsed value is an empty dict.
        (None, {}),
        # One modality carrying one option.
        ('{"video": {"num_frames": 123} }', {"video": {"num_frames": 123}}),
        # Several modalities; extra keys come through as given.
        (
            '{"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, "image": {"foo": "bar"} }',  # noqa
            {
                "video": {"num_frames": 123, "fps": 1.0, "foo": "bar"},
                "image": {"foo": "bar"},
            },
        ),
    ],
)
def test_media_io_kwargs_parser(arg, expected):
    """Check that `--media-io-kwargs` JSON is parsed into a nested dict.

    `arg` is the raw JSON string given on the command line, or None to
    leave the flag out entirely.
    """
    cli_argv = [] if arg is None else ["--media-io-kwargs", arg]
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    namespace = parser.parse_args(cli_argv)
    assert namespace.media_io_kwargs == expected
@pytest.mark.parametrize(
    ("arg", "expected"),
    [
        # Flag omitted: the parsed value is an empty dict.
        (None, {}),
        # Override for a single modality.
        (
            '{"video":"<|video_placeholder|>"}',
            {"video": "<|video_placeholder|>"},
        ),
        # Overrides for two modalities at once.
        (
            '{"video":"<|video_placeholder|>", "image": "<|image_placeholder|>"}',
            {
                "video": "<|video_placeholder|>",
                "image": "<|image_placeholder|>",
            },
        ),
    ],
)
def test_mm_placeholder_str_override_parser(arg, expected):
    """Check that `--mm-placeholder-str-override` JSON becomes a str->str dict.

    `arg` is the raw JSON string given on the command line, or None to
    leave the flag out entirely.
    """
    cli_argv = [] if arg is None else ["--mm-placeholder-str-override", arg]
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    namespace = parser.parse_args(cli_argv)
    assert namespace.mm_placeholder_str_override == expected
def test_compilation_config(): def test_compilation_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

View File

@ -3,8 +3,8 @@
import asyncio import asyncio
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Optional from typing import Any, Optional
from unittest.mock import MagicMock from unittest.mock import MagicMock
from vllm.config import MultiModalConfig from vllm.config import MultiModalConfig
@ -40,6 +40,8 @@ class MockModelConfig:
allowed_local_media_path: str = "" allowed_local_media_path: str = ""
encoder_config = None encoder_config = None
generation_config: str = "auto" generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
def get_diff_sampling_param(self): def get_diff_sampling_param(self):
return self.diff_sampling_param or {} return self.diff_sampling_param or {}

View File

@ -167,14 +167,14 @@ async def test_fetch_image_error_conversion():
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800]) @pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_fetch_video_http(video_url: str, num_frames: int): async def test_fetch_video_http(video_url: str, num_frames: int):
connector = MediaConnector() connector = MediaConnector(
media_io_kwargs={"video": {
"num_frames": num_frames,
}})
video_sync = connector.fetch_video(video_url, num_frames=num_frames) video_sync = connector.fetch_video(video_url)
video_async = await connector.fetch_video_async(video_url, video_async = await connector.fetch_video_async(video_url)
num_frames=num_frames)
# Check that the video frames are equal and metadata are same
assert np.array_equal(video_sync[0], video_async[0]) assert np.array_equal(video_sync[0], video_async[0])
assert video_sync[1] == video_async[1]
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`. # Used for the next two tests related to `merge_and_sort_multimodal_metadata`.

View File

@ -4,7 +4,10 @@ import numpy as np
import numpy.typing as npt import numpy.typing as npt
import pytest import pytest
from vllm.multimodal.video import VIDEO_LOADER_REGISTRY, VideoLoader from vllm import envs
from vllm.multimodal.image import ImageMediaIO
from vllm.multimodal.video import (VIDEO_LOADER_REGISTRY, VideoLoader,
VideoMediaIO)
NUM_FRAMES = 10 NUM_FRAMES = 10
FAKE_OUTPUT_1 = np.random.rand(NUM_FRAMES, 1280, 720, 3) FAKE_OUTPUT_1 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
@ -40,3 +43,46 @@ def test_video_loader_registry():
def test_video_loader_type_doesnt_exist(): def test_video_loader_type_doesnt_exist():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
VIDEO_LOADER_REGISTRY.load("non_existing_video_loader") VIDEO_LOADER_REGISTRY.load("non_existing_video_loader")
@VIDEO_LOADER_REGISTRY.register("assert_10_frames_1_fps")
class Assert10Frames1FPSVideoLoader(VideoLoader):
    """Test-only loader that checks the sampling arguments it is handed.

    Registered under the ``assert_10_frames_1_fps`` backend name.
    """

    @classmethod
    def load_bytes(
        cls,
        data: bytes,
        num_frames: int = -1,
        fps: float = -1.0,
        **kwargs,
    ) -> npt.NDArray:
        """Return ``FAKE_OUTPUT_2`` once the expected arguments are confirmed.

        ``data`` and any extra ``kwargs`` are accepted but not inspected.

        Raises:
            AssertionError: with "bad num_frames" when ``num_frames`` is not
                10, otherwise with "bad fps" when ``fps`` is not 1.0.
        """
        # num_frames is checked first, so it wins when both values are wrong.
        assert num_frames == 10, "bad num_frames"
        assert fps == 1.0, "bad fps"
        return FAKE_OUTPUT_2
def test_video_media_io_kwargs():
    """Verify that `VideoMediaIO` forwards its kwargs to the video loader.

    The backend is switched to the `assert_10_frames_1_fps` loader, which
    asserts on the `num_frames` and `fps` values it receives. The backend
    setting is restored afterwards so the fake loader does not leak into
    other tests running in the same process.
    """
    backend_attr = "VLLM_VIDEO_LOADER_BACKEND"
    unset = object()
    # Read the module namespace directly: this distinguishes "attribute was
    # never stored on the module" from "attribute stored with some value",
    # so the finally-block can put things back exactly as they were.
    previous = vars(envs).get(backend_attr, unset)
    setattr(envs, backend_attr, "assert_10_frames_1_fps")
    try:
        imageio = ImageMediaIO()

        # Matching arguments pass, with or without an extra unused key.
        VideoMediaIO(imageio, num_frames=10, fps=1.0).load_bytes(b"test")
        VideoMediaIO(
            imageio,
            num_frames=10,
            fps=1.0,
            not_used="not_used",
        ).load_bytes(b"test")

        # Missing or mismatching arguments trip the loader's assertions.
        failing_cases = [
            ({}, "bad num_frames"),
            ({"num_frames": 9, "fps": 1.0}, "bad num_frames"),
            ({"num_frames": 10, "fps": 2.0}, "bad fps"),
        ]
        for io_kwargs, message in failing_cases:
            # Construct outside the `raises` block so only the load call is
            # allowed to produce the expected AssertionError.
            videoio = VideoMediaIO(imageio, **io_kwargs)
            with pytest.raises(AssertionError, match=message):
                videoio.load_bytes(b"test")
    finally:
        if previous is unset:
            delattr(envs, backend_attr)
        else:
            setattr(envs, backend_attr, previous)

View File

@ -346,6 +346,12 @@ class ModelConfig:
limit_mm_per_prompt: dict[str, int] = field(default_factory=dict) limit_mm_per_prompt: dict[str, int] = field(default_factory=dict)
"""Maximum number of data items per modality per prompt. Only applicable """Maximum number of data items per modality per prompt. Only applicable
for multimodal models.""" for multimodal models."""
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
"""Optionally override placeholder string for given modalities."""
use_async_output_proc: bool = True use_async_output_proc: bool = True
"""Whether to use async output processor.""" """Whether to use async output processor."""
config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value
@ -694,6 +700,8 @@ class ModelConfig:
if self.registry.is_multimodal_model(self.architectures): if self.registry.is_multimodal_model(self.architectures):
return MultiModalConfig( return MultiModalConfig(
limit_per_prompt=self.limit_mm_per_prompt, limit_per_prompt=self.limit_mm_per_prompt,
media_io_kwargs=self.media_io_kwargs,
mm_placeholder_str_override=self.mm_placeholder_str_override,
mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_kwargs=self.mm_processor_kwargs,
disable_mm_preprocessor_cache=self. disable_mm_preprocessor_cache=self.
disable_mm_preprocessor_cache) disable_mm_preprocessor_cache)
@ -3063,6 +3071,14 @@ class MultiModalConfig:
`{"images": 16, "videos": 2}` `{"images": 16, "videos": 2}`
""" """
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Additional args passed to process media inputs, keyed by modalities.
For example, to set num_frames for video, set
`--media-io-kwargs '{"video": {"num_frames": 40} }'` """
mm_placeholder_str_override: dict[str, str] = field(default_factory=dict)
"""Optionally override placeholder string for given modalities."""
mm_processor_kwargs: Optional[dict[str, object]] = None mm_processor_kwargs: Optional[dict[str, object]] = None
""" """
Overrides for the multi-modal processor obtained from Overrides for the multi-modal processor obtained from

View File

@ -369,6 +369,11 @@ class EngineArgs:
get_field(TokenizerPoolConfig, "extra_config") get_field(TokenizerPoolConfig, "extra_config")
limit_mm_per_prompt: dict[str, int] = \ limit_mm_per_prompt: dict[str, int] = \
get_field(MultiModalConfig, "limit_per_prompt") get_field(MultiModalConfig, "limit_per_prompt")
media_io_kwargs: dict[str, dict[str,
Any]] = get_field(MultiModalConfig,
"media_io_kwargs")
mm_placeholder_str_override: dict[str, str] = \
get_field(MultiModalConfig, "mm_placeholder_str_override")
mm_processor_kwargs: Optional[Dict[str, Any]] = \ mm_processor_kwargs: Optional[Dict[str, Any]] = \
MultiModalConfig.mm_processor_kwargs MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = \ disable_mm_preprocessor_cache: bool = \
@ -745,6 +750,11 @@ class EngineArgs:
) )
multimodal_group.add_argument("--limit-mm-per-prompt", multimodal_group.add_argument("--limit-mm-per-prompt",
**multimodal_kwargs["limit_per_prompt"]) **multimodal_kwargs["limit_per_prompt"])
multimodal_group.add_argument("--media-io-kwargs",
**multimodal_kwargs["media_io_kwargs"])
multimodal_group.add_argument(
"--mm-placeholder-str-override",
**multimodal_kwargs["mm_placeholder_str_override"])
multimodal_group.add_argument( multimodal_group.add_argument(
"--mm-processor-kwargs", "--mm-processor-kwargs",
**multimodal_kwargs["mm_processor_kwargs"]) **multimodal_kwargs["mm_processor_kwargs"])
@ -969,6 +979,8 @@ class EngineArgs:
enable_prompt_embeds=self.enable_prompt_embeds, enable_prompt_embeds=self.enable_prompt_embeds,
served_model_name=self.served_model_name, served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt, limit_mm_per_prompt=self.limit_mm_per_prompt,
media_io_kwargs=self.media_io_kwargs,
mm_placeholder_str_override=self.mm_placeholder_str_override,
use_async_output_proc=not self.disable_async_output_proc, use_async_output_proc=not self.disable_async_output_proc,
config_format=self.config_format, config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_kwargs=self.mm_processor_kwargs,

View File

@ -507,6 +507,9 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def _placeholder_str(self, modality: ModalityStr, def _placeholder_str(self, modality: ModalityStr,
current_count: int) -> Optional[str]: current_count: int) -> Optional[str]:
if modality in self._model_config.mm_placeholder_str_override:
return self._model_config.mm_placeholder_str_override[modality]
# TODO: Let user specify how to insert image tokens into prompt # TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template) # (similar to chat template)
hf_config = self._model_config.hf_config hf_config = self._model_config.hf_config
@ -725,6 +728,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
self._tracker = tracker self._tracker = tracker
self._connector = MediaConnector( self._connector = MediaConnector(
media_io_kwargs=self._tracker._model_config.media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
) )
@ -763,7 +767,7 @@ class MultiModalContentParser(BaseMultiModalContentParser):
return self.parse_audio(audio_url) return self.parse_audio(audio_url)
def parse_video(self, video_url: str) -> None: def parse_video(self, video_url: str) -> None:
video = self._connector.fetch_video(video_url) video = self._connector.fetch_video(video_url=video_url)
placeholder = self._tracker.add("video", video) placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)
@ -776,7 +780,8 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._tracker = tracker self._tracker = tracker
self._connector = MediaConnector( self._connector = MediaConnector(
allowed_local_media_path=tracker.allowed_local_media_path, media_io_kwargs=self._tracker._model_config.media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path
) )
def parse_image(self, image_url: str) -> None: def parse_image(self, image_url: str) -> None:
@ -818,7 +823,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
return self.parse_audio(audio_url) return self.parse_audio(audio_url)
def parse_video(self, video_url: str) -> None: def parse_video(self, video_url: str) -> None:
video = self._connector.fetch_video_async(video_url) video = self._connector.fetch_video_async(video_url=video_url)
placeholder = self._tracker.add("video", video) placeholder = self._tracker.add("video", video)
self._add_placeholder(placeholder) self._add_placeholder(placeholder)

View File

@ -83,6 +83,16 @@ class AudioResampler:
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
def __init__(self, **kwargs) -> None:
    """Keep the per-modality options supplied for this media type.

    `kwargs` contains the custom arguments from `--media-io-kwargs` for
    this modality. They are stored on `self.kwargs` so they can be passed
    to the underlying media loaders (e.g. custom implementations) for
    flexible control.
    """
    super().__init__()
    self.kwargs = kwargs
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]: def load_bytes(self, data: bytes) -> tuple[npt.NDArray, float]:
return librosa.load(BytesIO(data), sr=None) return librosa.load(BytesIO(data), sr=None)

View File

@ -44,10 +44,16 @@ def convert_image_mode(image: Image.Image, to_mode: str):
class ImageMediaIO(MediaIO[Image.Image]): class ImageMediaIO(MediaIO[Image.Image]):
def __init__(self, *, image_mode: str = "RGB") -> None: def __init__(self, image_mode: str = "RGB", **kwargs) -> None:
super().__init__() super().__init__()
self.image_mode = image_mode self.image_mode = image_mode
# `kwargs` contains custom arguments from
# --media-io-kwargs for this modality.
# They can be passed to the underlying
# media loaders (e.g. custom implementations)
# for flexible control.
self.kwargs = kwargs
def load_bytes(self, data: bytes) -> Image.Image: def load_bytes(self, data: bytes) -> Image.Image:
image = Image.open(BytesIO(data)) image = Image.open(BytesIO(data))

View File

@ -38,12 +38,15 @@ class MediaConnector:
def __init__( def __init__(
self, self,
media_io_kwargs: Optional[dict[str, dict[str, Any]]] = None,
connection: HTTPConnection = global_http_connection, connection: HTTPConnection = global_http_connection,
*, *,
allowed_local_media_path: str = "", allowed_local_media_path: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.media_io_kwargs: dict[str, dict[
str, Any]] = media_io_kwargs if media_io_kwargs else {}
self.connection = connection self.connection = connection
if allowed_local_media_path: if allowed_local_media_path:
@ -149,7 +152,7 @@ class MediaConnector:
""" """
Load audio from a URL. Load audio from a URL.
""" """
audio_io = AudioMediaIO() audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))
return self.load_from_url( return self.load_from_url(
audio_url, audio_url,
@ -164,7 +167,7 @@ class MediaConnector:
""" """
Asynchronously fetch audio from a URL. Asynchronously fetch audio from a URL.
""" """
audio_io = AudioMediaIO() audio_io = AudioMediaIO(**self.media_io_kwargs.get("audio", {}))
return await self.load_from_url_async( return await self.load_from_url_async(
audio_url, audio_url,
@ -183,7 +186,8 @@ class MediaConnector:
By default, the image is converted into RGB format. By default, the image is converted into RGB format.
""" """
image_io = ImageMediaIO(image_mode=image_mode) image_io = ImageMediaIO(image_mode=image_mode,
**self.media_io_kwargs.get("image", {}))
try: try:
return self.load_from_url( return self.load_from_url(
@ -206,7 +210,8 @@ class MediaConnector:
By default, the image is converted into RGB format. By default, the image is converted into RGB format.
""" """
image_io = ImageMediaIO(image_mode=image_mode) image_io = ImageMediaIO(image_mode=image_mode,
**self.media_io_kwargs.get("image", {}))
try: try:
return await self.load_from_url_async( return await self.load_from_url_async(
@ -223,13 +228,14 @@ class MediaConnector:
video_url: str, video_url: str,
*, *,
image_mode: str = "RGB", image_mode: str = "RGB",
num_frames: int = 32,
) -> npt.NDArray: ) -> npt.NDArray:
""" """
Load video from a HTTP or base64 data URL. Load video from a HTTP or base64 data URL.
""" """
image_io = ImageMediaIO(image_mode=image_mode) image_io = ImageMediaIO(image_mode=image_mode,
video_io = VideoMediaIO(image_io, num_frames=num_frames) **self.media_io_kwargs.get("image", {}))
video_io = VideoMediaIO(image_io,
**self.media_io_kwargs.get("video", {}))
return self.load_from_url( return self.load_from_url(
video_url, video_url,
@ -242,15 +248,16 @@ class MediaConnector:
video_url: str, video_url: str,
*, *,
image_mode: str = "RGB", image_mode: str = "RGB",
num_frames: int = 32,
) -> npt.NDArray: ) -> npt.NDArray:
""" """
Asynchronously load video from a HTTP or base64 data URL. Asynchronously load video from a HTTP or base64 data URL.
By default, the image is converted into RGB format. By default, the image is converted into RGB format.
""" """
image_io = ImageMediaIO(image_mode=image_mode) image_io = ImageMediaIO(image_mode=image_mode,
video_io = VideoMediaIO(image_io, num_frames=num_frames) **self.media_io_kwargs.get("image", {}))
video_io = VideoMediaIO(image_io,
**self.media_io_kwargs.get("video", {}))
return await self.load_from_url_async( return await self.load_from_url_async(
video_url, video_url,

View File

@ -54,7 +54,10 @@ class VideoLoader:
@classmethod @classmethod
@abstractmethod @abstractmethod
def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray: def load_bytes(cls,
data: bytes,
num_frames: int = -1,
**kwargs) -> npt.NDArray:
raise NotImplementedError raise NotImplementedError
@ -102,7 +105,8 @@ class OpenCVVideoBackend(VideoLoader):
@classmethod @classmethod
def load_bytes(cls, def load_bytes(cls,
data: bytes, data: bytes,
num_frames: int = -1) -> tuple[npt.NDArray, dict]: num_frames: int = -1,
**kwargs) -> npt.NDArray:
import cv2 import cv2
backend = cls().get_cv2_video_api() backend = cls().get_cv2_video_api()
@ -159,18 +163,26 @@ class VideoMediaIO(MediaIO[npt.NDArray]):
def __init__( def __init__(
self, self,
image_io: ImageMediaIO, image_io: ImageMediaIO,
*,
num_frames: int = 32, num_frames: int = 32,
**kwargs,
) -> None: ) -> None:
super().__init__() super().__init__()
self.image_io = image_io self.image_io = image_io
self.num_frames = num_frames self.num_frames = num_frames
# `kwargs` contains custom arguments from
# --media-io-kwargs for this modality.
# They can be passed to the underlying
# media loaders (e.g. custom implementations)
# for flexible control.
self.kwargs = kwargs
video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND video_loader_backend = envs.VLLM_VIDEO_LOADER_BACKEND
self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend) self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)
def load_bytes(self, data: bytes) -> npt.NDArray: def load_bytes(self, data: bytes) -> npt.NDArray:
return self.video_loader.load_bytes(data, self.num_frames) return self.video_loader.load_bytes(data,
num_frames=self.num_frames,
**self.kwargs)
def load_base64(self, media_type: str, data: str) -> npt.NDArray: def load_base64(self, media_type: str, data: str) -> npt.NDArray:
if media_type.lower() == "video/jpeg": if media_type.lower() == "video/jpeg":