Feature/video support in random mm dataset (#25963)

Signed-off-by: Eugene Khvedchenya <ekhvedchenia@nvidia.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Eugene Khvedchenya 2025-10-29 12:24:52 +02:00, committed by GitHub
parent 1a33aacf82
commit 5e72216d17
3 changed files with 601 additions and 25 deletions

View File

@ -359,3 +359,126 @@ def test_random_mm_bucket_config_not_mutated(
        assert len(mm_data) >= 1
        for it in mm_data:
            assert it.get("type") == "image_url"

@pytest.mark.benchmark
def test_random_mm_video_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
    """Test video sampling functionality in RandomMultiModalDataset."""
    ds = RandomMultiModalDataset(random_seed=42)

    # Test with video bucket configuration
    bucket_config = {
        (64, 64, 1): 0.3,  # Images
        (64, 64, 8): 0.7,  # Videos
    }
    limit_mm_per_prompt = {"image": 2, "video": 2}

    samples = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=5,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
    )
    assert len(samples) == 5

    # Check that we have both images and videos
    video_count = 0
    image_count = 0
    for s in samples:
        mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
        assert len(mm_data) == 1
        item = mm_data[0]
        if item.get("type") == "video_url":
            video_count += 1
            # Verify video URL format
            url = item.get("video_url", {}).get("url", "")
            assert url.startswith("data:video/mp4;base64,")
        elif item.get("type") == "image_url":
            image_count += 1
            # Verify image URL format
            url = item.get("image_url", {}).get("url", "")
            assert url.startswith("data:image/jpeg;base64,")

    # Should have some videos due to 0.7 probability
    assert video_count > 0
    assert image_count > 0

@pytest.mark.benchmark
def test_random_mm_video_only_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
    """Test sampling with only video buckets."""
    ds = RandomMultiModalDataset(random_seed=42)
    bucket_config = {
        (64, 64, 8): 1.0,  # Only videos
    }
    limit_mm_per_prompt = {"image": 0, "video": 1}

    samples = _collect_mm_samples(
        ds,
        hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
    )
    assert len(samples) == 3

    for s in samples:
        mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
        assert len(mm_data) == 1
        item = mm_data[0]
        assert item.get("type") == "video_url"
        url = item.get("video_url", {}).get("url", "")
        assert url.startswith("data:video/mp4;base64,")

@pytest.mark.benchmark
def test_random_mm_video_deterministic_sampling(
    hf_tokenizer: PreTrainedTokenizerBase,
) -> None:
    """Test that video sampling is deterministic with the same seed."""
    seed = 123
    ds_a = RandomMultiModalDataset(random_seed=seed)
    ds_b = RandomMultiModalDataset(random_seed=seed)
    bucket_config = {
        (64, 64, 8): 1.0,  # Only videos
    }
    limit_mm_per_prompt = {"image": 0, "video": 1}

    a = _collect_mm_samples(
        ds_a,
        hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
    )
    b = _collect_mm_samples(
        ds_b,
        hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
    )

    fa = [_mm_fingerprint_sample(s) for s in a]
    fb = [_mm_fingerprint_sample(s) for s in b]
    assert fa == fb
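
For reference, `_collect_mm_samples` and `_mm_fingerprint_sample` are helpers defined earlier in this test module, above the hunk shown here. A plausible sketch of them, inferred from the call sites above; the fixed `input_len`/`output_len` values and the fingerprint fields are assumptions, not part of the diff:

from typing import Any, cast

def _collect_mm_samples(ds, tokenizer, **kwargs):
    # Thin wrapper over RandomMultiModalDataset.sample with fixed text
    # lengths (the 20/5 values here are assumed for illustration).
    return ds.sample(tokenizer=tokenizer, input_len=20, output_len=5, **kwargs)

def _mm_fingerprint_sample(s):
    # Reduce a sample's multimodal payload to a hashable summary so two
    # sampling runs can be compared for determinism.
    mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
    return tuple(
        (item.get("type"), item.get(item.get("type"), {}).get("url", ""))
        for item in mm_data
    )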

View File

@ -0,0 +1,398 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import os
from tempfile import NamedTemporaryFile
from typing import Any, cast

import cv2
import pytest
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm.benchmarks.datasets import RandomMultiModalDataset, SampleRequest


@pytest.fixture(scope="session")
def hf_tokenizer() -> PreTrainedTokenizerBase:
    """Use a small, commonly available tokenizer."""
    return AutoTokenizer.from_pretrained("gpt2")


@pytest.fixture
def video_dataset() -> RandomMultiModalDataset:
    """Create a RandomMultiModalDataset instance for testing."""
    return RandomMultiModalDataset(random_seed=42)

@pytest.mark.benchmark
def test_generate_synthetic_video_different_seeds():
    """Test that different seeds produce different videos."""
    dataset1 = RandomMultiModalDataset(random_seed=123)
    dataset2 = RandomMultiModalDataset(random_seed=456)
    width, height, num_frames = 64, 48, 8

    video1 = dataset1.generate_synthetic_video(width, height, num_frames)
    video2 = dataset2.generate_synthetic_video(width, height, num_frames)

    # Videos should be different due to different seeds
    assert video1["bytes"] != video2["bytes"]

@pytest.mark.benchmark
def test_map_config_to_modality(video_dataset: RandomMultiModalDataset):
    """Test modality mapping for different configurations."""
    # Test image configurations (num_frames = 1)
    assert video_dataset.map_config_to_modality((256, 256, 1)) == "image"
    assert video_dataset.map_config_to_modality((720, 1280, 1)) == "image"

    # Test video configurations (num_frames > 1)
    assert video_dataset.map_config_to_modality((256, 256, 8)) == "video"
    assert video_dataset.map_config_to_modality((720, 1280, 16)) == "video"
    assert video_dataset.map_config_to_modality((64, 64, 32)) == "video"

    # Test invalid configurations
    with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
        video_dataset.map_config_to_modality((256, 256, 0))
    with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
        video_dataset.map_config_to_modality((256, 256, -1))

@pytest.mark.benchmark
def test_generate_mm_item_video(video_dataset: RandomMultiModalDataset):
    """Test generating multimodal items for video configurations."""
    # Test video item generation
    video_config = (64, 48, 8)  # height, width, num_frames
    result = video_dataset.generate_mm_item(video_config)

    # Check the result structure matches OpenAI API format
    assert isinstance(result, dict)
    assert result["type"] == "video_url"
    assert "video_url" in result
    assert "url" in result["video_url"]

    # Check that the URL is a data URL with base64 encoded video
    url = result["video_url"]["url"]
    assert url.startswith("data:video/mp4;base64,")

    # Decode and verify the video content
    base64_data = url.split(",")[1]
    video_bytes = base64.b64decode(base64_data)
    assert len(video_bytes) > 0

    # Verify the video can be decoded
    with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
        temp_path = temp_file.name
        temp_file.write(video_bytes)
    try:
        cap = cv2.VideoCapture(temp_path)
        assert cap.isOpened()
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        assert frame_count == 8
        assert frame_width == 48
        assert frame_height == 64
        cap.release()
    finally:
        if os.path.exists(temp_path):
            os.unlink(temp_path)

@pytest.mark.benchmark
def test_generate_mm_item_image(video_dataset: RandomMultiModalDataset):
    """Test generating multimodal items for image configurations."""
    # Test image item generation
    image_config = (64, 48, 1)  # height, width, num_frames=1
    result = video_dataset.generate_mm_item(image_config)

    # Check the result structure matches OpenAI API format
    assert isinstance(result, dict)
    assert result["type"] == "image_url"
    assert "image_url" in result
    assert "url" in result["image_url"]

    # Check that the URL is a data URL with base64 encoded image
    url = result["image_url"]["url"]
    assert url.startswith("data:image/jpeg;base64,")

@pytest.mark.benchmark
def test_generate_mm_item_invalid_config(video_dataset: RandomMultiModalDataset):
    """Test error handling for invalid configurations."""
    with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
        video_dataset.generate_mm_item((256, 256, 0))

@pytest.mark.benchmark
def test_sample_with_video_buckets(
    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
    """Test sampling with video bucket configurations."""
    # Configure bucket with video probability > 0
    bucket_config = {
        (64, 64, 1): 0.3,  # Images
        (64, 64, 8): 0.7,  # Videos
    }
    limit_mm_per_prompt = {"image": 5, "video": 3}

    samples = video_dataset.sample(
        tokenizer=hf_tokenizer,
        num_requests=5,
        base_items_per_request=2,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )
    assert len(samples) == 5

    # Check that samples contain both images and videos
    video_count = 0
    image_count = 0
    for sample in samples:
        assert isinstance(sample, SampleRequest)
        assert sample.multi_modal_data is not None
        assert isinstance(sample.multi_modal_data, list)
        mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
        assert len(mm_data) == 2  # base_items_per_request
        for item in mm_data:
            if item["type"] == "video_url":
                video_count += 1
                # Verify video URL format
                url = item["video_url"]["url"]
                assert url.startswith("data:video/mp4;base64,")
            elif item["type"] == "image_url":
                image_count += 1
                # Verify image URL format
                url = item["image_url"]["url"]
                assert url.startswith("data:image/jpeg;base64,")

    # Should have some videos due to 0.7 probability
    assert video_count > 0
    assert image_count > 0

@pytest.mark.benchmark
def test_sample_video_only_buckets(
    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
    """Test sampling with only video buckets."""
    bucket_config = {
        (64, 64, 8): 1.0,  # Only videos
    }
    limit_mm_per_prompt = {"image": 0, "video": 2}

    samples = video_dataset.sample(
        tokenizer=hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )
    assert len(samples) == 3

    for sample in samples:
        assert isinstance(sample, SampleRequest)
        assert sample.multi_modal_data is not None
        assert isinstance(sample.multi_modal_data, list)
        mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
        assert len(mm_data) == 1
        item = mm_data[0]
        assert item["type"] == "video_url"
        url = item["video_url"]["url"]
        assert url.startswith("data:video/mp4;base64,")

@pytest.mark.benchmark
def test_sample_respects_video_limits(
    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
    """Test that sampling respects video limits per prompt."""
    bucket_config = {
        (64, 64, 8): 1.0,  # Only videos
    }
    # Set very low video limit
    limit_mm_per_prompt = {"image": 0, "video": 1}

    samples = video_dataset.sample(
        tokenizer=hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )
    assert len(samples) == 3

    for sample in samples:
        mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
        assert len(mm_data) <= 1  # Should respect video limit

@pytest.mark.benchmark
def test_sample_mixed_buckets_with_zero_probability(
    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
):
    """Test sampling with mixed buckets including zero probability entries."""
    bucket_config = {
        (64, 64, 1): 0.5,  # Images
        (64, 64, 8): 0.5,  # Videos
        (128, 128, 16): 0.0,  # Zero probability videos (should be ignored)
    }
    limit_mm_per_prompt = {"image": 2, "video": 2}

    samples = video_dataset.sample(
        tokenizer=hf_tokenizer,
        num_requests=4,
        base_items_per_request=2,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )
    assert len(samples) == 4

    # Should only see 64x64 videos, not 128x128 videos
    for sample in samples:
        mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
        for item in mm_data:
            if item["type"] == "video_url":
                # Decode video to verify dimensions
                url = item["video_url"]["url"]
                base64_data = url.split(",")[1]
                video_bytes = base64.b64decode(base64_data)
                with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:  # noqa
                    temp_path = temp_file.name
                    temp_file.write(video_bytes)
                try:
                    cap = cv2.VideoCapture(temp_path)
                    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    cap.release()
                    # Should be 64x64, not 128x128
                    assert frame_width == 64
                    assert frame_height == 64
                finally:
                    if os.path.exists(temp_path):
                        os.unlink(temp_path)

@pytest.mark.benchmark
def test_sample_deterministic_with_videos(hf_tokenizer: PreTrainedTokenizerBase):
    """Test that sampling with videos is deterministic with the same seed."""
    dataset1 = RandomMultiModalDataset(random_seed=123)
    dataset2 = RandomMultiModalDataset(random_seed=123)
    bucket_config = {
        (64, 64, 1): 0.3,  # Images
        (64, 64, 8): 0.7,  # Videos
    }
    limit_mm_per_prompt = {"image": 2, "video": 2}

    samples1 = dataset1.sample(
        tokenizer=hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )
    samples2 = dataset2.sample(
        tokenizer=hf_tokenizer,
        num_requests=3,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )
    assert len(samples1) == len(samples2)

    # Compare multimodal data
    for s1, s2 in zip(samples1, samples2):
        assert s1.multi_modal_data == s2.multi_modal_data

@pytest.mark.benchmark
def test_sample_different_seeds_produce_different_videos(
    hf_tokenizer: PreTrainedTokenizerBase,
):
    """Test that different seeds produce different video content."""
    dataset1 = RandomMultiModalDataset(random_seed=123)
    dataset2 = RandomMultiModalDataset(random_seed=456)
    bucket_config = {
        (64, 64, 8): 1.0,  # Only videos
    }
    limit_mm_per_prompt = {"image": 0, "video": 1}

    samples1 = dataset1.sample(
        tokenizer=hf_tokenizer,
        num_requests=2,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )
    samples2 = dataset2.sample(
        tokenizer=hf_tokenizer,
        num_requests=2,
        base_items_per_request=1,
        num_mm_items_range_ratio=0.0,
        limit_mm_per_prompt=limit_mm_per_prompt,
        bucket_config=bucket_config,
        input_len=20,
        output_len=5,
    )

    # Video content should be different
    for s1, s2 in zip(samples1, samples2):
        mm_data1 = cast(list[dict[str, Any]], s1.multi_modal_data)
        mm_data2 = cast(list[dict[str, Any]], s2.multi_modal_data)
        assert len(mm_data1) == len(mm_data2) == 1
        url1 = mm_data1[0]["video_url"]["url"]
        url2 = mm_data2[0]["video_url"]["url"]
        assert url1 != url2  # Different video content

View File

@ -27,8 +27,10 @@ from copy import deepcopy
from dataclasses import dataclass
from functools import cache
from io import BytesIO
from tempfile import NamedTemporaryFile
from typing import Any, cast

import cv2
import numpy as np
from PIL import Image
from transformers import PreTrainedTokenizerBase
@ -498,9 +500,13 @@ class RandomDataset(BenchmarkDataset):
            num_requests, range_ratio, input_len, output_len, tokenizer
        )

        # Generate prefix once
        prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
        vocab_size = tokenizer.vocab_size
        prohibited_tokens = tokenizer.all_special_ids
        all_tokens = np.arange(vocab_size)
        allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens)))

        # Generate prefix once
        prefix_token_ids = self.get_prefix(allowed_tokens, prefix_len)

        requests = []
        token_mismatch_total = 0
@ -513,6 +519,7 @@ class RandomDataset(BenchmarkDataset):
                input_len=int(input_lens[i]),
                offset=int(offsets[i]),
                index=i,
                allowed_tokens=allowed_tokens,
            )
            token_mismatch_total += token_mismatch
            requests.append(
@ -553,13 +560,17 @@ class RandomDataset(BenchmarkDataset):
        return requests

    def get_prefix(
        self, tokenizer: PreTrainedTokenizerBase, prefix_len: int
        self,
        allowed_tokens: np.ndarray,
        prefix_len: int,
    ) -> list[int]:
        """
        Get the prefix for the dataset.
        """
        return (
            self._rng.integers(0, tokenizer.vocab_size, size=prefix_len).tolist()
            allowed_tokens[
                self._rng.integers(0, len(allowed_tokens), size=prefix_len)
            ].tolist()
            if prefix_len > 0
            else []
        )
@ -623,6 +634,7 @@ class RandomDataset(BenchmarkDataset):
        input_len: int,
        offset: int,
        index: int,
        allowed_tokens: np.ndarray,
    ) -> tuple[str, int, int]:
        """
        Returns (prompt, total_input_len, token_mismatch).
@ -636,8 +648,11 @@ class RandomDataset(BenchmarkDataset):
        To avoid uncontrolled change of the prompt length,
        the encoded sequence is truncated before being decoded again.
        """
        # Build the inner sequence by sampling sequentially from the vocab
        inner_seq = ((offset + index + np.arange(input_len)) % vocab_size).tolist()
        # Build the inner sequence by sampling
        # sequentially from the allowed tokens
        inner_seq = allowed_tokens[
            (offset + index + np.arange(input_len)) % len(allowed_tokens)
        ].tolist()
        token_sequence = prefix_token_ids + inner_seq

        # Decode, then re-encode and truncate to preserve token count invariants
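
The decode/re-encode invariant mentioned in that comment works roughly as follows; this is a sketch under assumed names (the exact helper lies outside this hunk), not the verbatim vLLM code:

def _round_trip_truncate(tokenizer, token_sequence: list[int]) -> tuple[str, int]:
    # Decode, re-encode, and truncate so the prompt's token count matches the
    # intended length; report how far the re-encoding drifted.
    target_len = len(token_sequence)
    reencoded = tokenizer.encode(tokenizer.decode(token_sequence), add_special_tokens=False)
    token_mismatch = len(reencoded) - target_len
    return tokenizer.decode(reencoded[:target_len]), token_mismatch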
@ -772,7 +787,7 @@ class RandomMultiModalDataset(RandomDataset):
    Status:
    - Images: supported via synthetic RGB data.
    - Video: not yet supported (TODO: implement video generation method).
    - Video: supported via synthetic RGB data.
    - Audio: not yet supported.

    Sampling overview:
@ -782,7 +797,7 @@ class RandomMultiModalDataset(RandomDataset):
    The maximum is further clamped to the sum of per-modality limits.
    2) Each item's modality and shape is sampled from `bucket_config`, a dict
       mapping (height, width, num_frames) to a probability. We treat
       `num_frames`=1 as image and and `num_frames` > 1 as video.
       `num_frames`=1 as image and `num_frames` > 1 as video.
       Entries with zero probability are removed and the rest are renormalized
       to sum to 1.
    3) Per-modality hard caps are enforced via `limit_mm_per_prompt`.
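
As a concrete illustration of the zero-drop and renormalization in step 2, a minimal standalone sketch (the helper name is illustrative, not from the diff):

def _normalize_buckets(
    bucket_config: dict[tuple[int, int, int], float],
) -> dict[tuple[int, int, int], float]:
    # Drop zero-probability entries, then rescale the rest to sum to 1.
    nonzero = {cfg: p for cfg, p in bucket_config.items() if p > 0}
    total = sum(nonzero.values())
    return {cfg: p / total for cfg, p in nonzero.items()}

For example, {(64, 64, 1): 0.2, (64, 64, 8): 0.6, (128, 128, 16): 0.0} becomes {(64, 64, 1): 0.25, (64, 64, 8): 0.75}.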
@ -797,8 +812,7 @@ class RandomMultiModalDataset(RandomDataset):
"""
IS_MULTIMODAL = True
# NOTE: video sampling is WIP. Setting it to 0.
DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 0}
DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 1}
DEFAULT_BASE_ITEMS_PER_REQUEST = 1
DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0
@ -828,12 +842,47 @@ class RandomMultiModalDataset(RandomDataset):
        )
        return Image.fromarray(random_pixels)

    def generate_synthetic_video(self, width: int, height: int, num_frames: int) -> Any:
    def generate_synthetic_video(
        self, width: int, height: int, num_frames: int
    ) -> dict:
        """Generate synthetic video with random values.

        TODO: Finish this method.
        Creates a video with random pixel values, encodes it to MP4 format,
        and returns the content as bytes.
        """
        raise NotImplementedError("Video sampling is WIP.")
        random_pixels = self._rng.integers(
            0,
            256,
            (num_frames, height, width, 3),
            dtype=np.uint8,
        )

        # Encode the frames to MP4 via a temporary file on disk
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        fps = 30  # frames per second
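        # NOTE: delete_on_close=False requires Python 3.12+. The file is
        # closed explicitly below so it can be reopened by path, and it is
        # still deleted when the context manager exits.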
        with NamedTemporaryFile(suffix=".mp4", delete_on_close=False) as temp_file:
            temp_path = temp_file.name

            # Create video writer
            video_writer = cv2.VideoWriter(
                temp_path, fourcc=fourcc, fps=fps, frameSize=(width, height)
            )
            if not video_writer.isOpened():
                raise RuntimeError("Failed to create video writer")

            for frame in random_pixels:
                video_writer.write(frame)
            video_writer.release()
            temp_file.close()

            # Read the video file content
            with open(temp_path, "rb") as f:
                video_content = f.read()

        return {"bytes": video_content}
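
Downstream, `generate_mm_item` wraps these bytes in an OpenAI-style data URL; that plumbing sits outside this hunk, but the new tests pin down the format. A minimal sketch of the encoding step, assuming only the `{"bytes": ...}` return shape shown here (the helper name is illustrative):

import base64

def _to_video_data_url(video: dict) -> dict:
    # Base64-encode the MP4 bytes and wrap them in the chat-completions
    # video_url message shape asserted by the tests.
    b64 = base64.b64encode(video["bytes"]).decode("utf-8")
    return {"type": "video_url", "video_url": {"url": f"data:video/mp4;base64,{b64}"}}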
    def map_config_to_modality(self, config: tuple[int, int, int]) -> str:
        """Map the configuration to the modality."""
@ -1044,16 +1093,6 @@ class RandomMultiModalDataset(RandomDataset):
        enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT,
        **kwargs,
    ) -> list[SampleRequest]:
        # NOTE: Video sampling is WIP. Raise error if video is in bucket config
        # and probability is non-zero.
        if any(
            self.map_config_to_modality(cfg) == "video" and p > 0
            for cfg, p in bucket_config.items()
        ):
            raise NotImplementedError(
                "Video sampling not implemented; set its probability to 0."
            )

        # Get the sampling parameters for the dataset
        input_lens, output_lens, offsets = self.get_sampling_params(
            num_requests, range_ratio, input_len, output_len, tokenizer
@ -1071,9 +1110,24 @@ class RandomMultiModalDataset(RandomDataset):
            bucket_config,
        )

        # Generate prefix once
        prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
        vocab_size = tokenizer.vocab_size
        # Can't use tokenizer.all_special_ids, since that returns only the ids
        # listed in special_tokens_map.json. We also want to exclude placeholder
        # tokens and any tokens that mark the start/end of an image, as they
        # may break prompt-replacement logic.
        prohibited_tokens = [
            tok_id
            for tok_id, token in tokenizer.added_tokens_decoder.items()
            if token.special
        ]
        all_tokens = np.arange(vocab_size)
        allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens)))
        logger.debug(
            "Sampling from %d out of %d (vocab size)", len(allowed_tokens), vocab_size
        )

        # Generate prefix once
        prefix_token_ids = self.get_prefix(allowed_tokens, prefix_len)

        # Add synthetic multimodal items to each request
        mm_requests = []
        token_mismatch_total = 0
@ -1086,6 +1140,7 @@ class RandomMultiModalDataset(RandomDataset):
                input_len=int(input_lens[i]),
                offset=int(offsets[i]),
                index=i,
                allowed_tokens=allowed_tokens,
            )
            token_mismatch_total += token_mismatch

            # Get multimodal item iterator for a given request
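
Putting it together, a minimal end-to-end sketch of the new video-capable sampling, mirroring the call shape used throughout the tests (the tokenizer choice and parameter values are illustrative):

from transformers import AutoTokenizer
from vllm.benchmarks.datasets import RandomMultiModalDataset

tokenizer = AutoTokenizer.from_pretrained("gpt2")
ds = RandomMultiModalDataset(random_seed=42)
requests = ds.sample(
    tokenizer=tokenizer,
    num_requests=4,
    base_items_per_request=1,
    num_mm_items_range_ratio=0.0,
    limit_mm_per_prompt={"image": 1, "video": 1},
    bucket_config={(64, 64, 1): 0.5, (64, 64, 8): 0.5},  # half images, half videos
    input_len=20,
    output_len=5,
)
# Each request carries OpenAI-style image_url / video_url items with
# base64 data URLs, ready to drop into chat-completions payloads.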