Feature/video support in random mm dataset (#25963)

Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com>
Signed-off-by: Eugene Khvedchenya <ekhvedchenia@nvidia.com>
Co-authored-by: Roger Wang <hey@rogerw.io>

parent 1a33aacf82
commit 5e72216d17
tests/benchmarks/test_random_multimodal_dataset.py
@@ -359,3 +359,126 @@ def test_random_mm_bucket_config_not_mutated(
         assert len(mm_data) >= 1
         for it in mm_data:
             assert it.get("type") == "image_url"
+
+
+@pytest.mark.benchmark
+def test_random_mm_video_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
+    """Test video sampling functionality in RandomMultiModalDataset."""
+    ds = RandomMultiModalDataset(random_seed=42)
+
+    # Test with video bucket configuration
+    bucket_config = {
+        (64, 64, 1): 0.3,  # Images
+        (64, 64, 8): 0.7,  # Videos
+    }
+
+    limit_mm_per_prompt = {"image": 2, "video": 2}
+
+    samples = _collect_mm_samples(
+        ds,
+        hf_tokenizer,
+        num_requests=5,
+        base_items_per_request=1,
+        num_mm_items_range_ratio=0.0,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+        bucket_config=bucket_config,
+    )
+
+    assert len(samples) == 5
+
+    # Check that we have both images and videos
+    video_count = 0
+    image_count = 0
+
+    for s in samples:
+        mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
+        assert len(mm_data) == 1
+
+        item = mm_data[0]
+        if item.get("type") == "video_url":
+            video_count += 1
+            # Verify video URL format
+            url = item.get("video_url", {}).get("url", "")
+            assert url.startswith("data:video/mp4;base64,")
+        elif item.get("type") == "image_url":
+            image_count += 1
+            # Verify image URL format
+            url = item.get("image_url", {}).get("url", "")
+            assert url.startswith("data:image/jpeg;base64,")
+
+    # Should have some videos due to 0.7 probability
+    assert video_count > 0
+    assert image_count > 0
+
+
+@pytest.mark.benchmark
+def test_random_mm_video_only_sampling(hf_tokenizer: PreTrainedTokenizerBase) -> None:
+    """Test sampling with only video buckets."""
+    ds = RandomMultiModalDataset(random_seed=42)
+
+    bucket_config = {
+        (64, 64, 8): 1.0,  # Only videos
+    }
+
+    limit_mm_per_prompt = {"image": 0, "video": 1}
+
+    samples = _collect_mm_samples(
+        ds,
+        hf_tokenizer,
+        num_requests=3,
+        base_items_per_request=1,
+        num_mm_items_range_ratio=0.0,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+        bucket_config=bucket_config,
+    )
+
+    assert len(samples) == 3
+
+    for s in samples:
+        mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
+        assert len(mm_data) == 1
+
+        item = mm_data[0]
+        assert item.get("type") == "video_url"
+        url = item.get("video_url", {}).get("url", "")
+        assert url.startswith("data:video/mp4;base64,")
+
+
+@pytest.mark.benchmark
+def test_random_mm_video_deterministic_sampling(
+    hf_tokenizer: PreTrainedTokenizerBase,
+) -> None:
+    """Test that video sampling is deterministic with same seed."""
+    seed = 123
+    ds_a = RandomMultiModalDataset(random_seed=seed)
+    ds_b = RandomMultiModalDataset(random_seed=seed)
+
+    bucket_config = {
+        (64, 64, 8): 1.0,  # Only videos
+    }
+
+    limit_mm_per_prompt = {"image": 0, "video": 1}
+
+    a = _collect_mm_samples(
+        ds_a,
+        hf_tokenizer,
+        num_requests=3,
+        base_items_per_request=1,
+        num_mm_items_range_ratio=0.0,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+        bucket_config=bucket_config,
+    )
+
+    b = _collect_mm_samples(
+        ds_b,
+        hf_tokenizer,
+        num_requests=3,
+        base_items_per_request=1,
+        num_mm_items_range_ratio=0.0,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+        bucket_config=bucket_config,
+    )
+
+    fa = [_mm_fingerprint_sample(s) for s in a]
+    fb = [_mm_fingerprint_sample(s) for s in b]
+    assert fa == fb
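Note: `_collect_mm_samples` and `_mm_fingerprint_sample` are pre-existing helpers defined earlier in this test module; they are not part of this hunk. For orientation only, a hypothetical sketch of what the fingerprint helper could look like (this is not the repository's actual definition):

    def _mm_fingerprint_sample(s: SampleRequest) -> tuple[str, ...]:
        # Hypothetical: reduce a sample's multimodal payload to a hashable
        # summary so two sampling runs can be compared for determinism.
        mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
        parts = []
        for item in mm_data:
            kind = item.get("type", "")  # "image_url" or "video_url"
            url = item.get(kind, {}).get("url", "")
            parts.append(f"{kind}:{url}")
        return tuple(parts)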
tests/benchmarks/test_random_multimodal_dataset_video.py (new file, 398 lines)
@@ -0,0 +1,398 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import base64
+import os
+from tempfile import NamedTemporaryFile
+from typing import Any, cast
+
+import cv2
+import pytest
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+
+from vllm.benchmarks.datasets import RandomMultiModalDataset, SampleRequest
+
+
+@pytest.fixture(scope="session")
+def hf_tokenizer() -> PreTrainedTokenizerBase:
+    """Use a small, commonly available tokenizer."""
+    return AutoTokenizer.from_pretrained("gpt2")
+
+
+@pytest.fixture
+def video_dataset() -> RandomMultiModalDataset:
+    """Create a RandomMultiModalDataset instance for testing."""
+    return RandomMultiModalDataset(random_seed=42)
+
+
+@pytest.mark.benchmark
+def test_generate_synthetic_video_different_seeds():
+    """Test that different seeds produce different videos."""
+    dataset1 = RandomMultiModalDataset(random_seed=123)
+    dataset2 = RandomMultiModalDataset(random_seed=456)
+
+    width, height, num_frames = 64, 48, 8
+
+    video1 = dataset1.generate_synthetic_video(width, height, num_frames)
+    video2 = dataset2.generate_synthetic_video(width, height, num_frames)
+
+    # Videos should be different due to different seeds
+    assert video1["bytes"] != video2["bytes"]
+
+
+@pytest.mark.benchmark
+def test_map_config_to_modality(video_dataset: RandomMultiModalDataset):
+    """Test modality mapping for different configurations."""
+    # Test image configuration (num_frames = 1)
+    assert video_dataset.map_config_to_modality((256, 256, 1)) == "image"
+    assert video_dataset.map_config_to_modality((720, 1280, 1)) == "image"
+
+    # Test video configurations (num_frames > 1)
+    assert video_dataset.map_config_to_modality((256, 256, 8)) == "video"
+    assert video_dataset.map_config_to_modality((720, 1280, 16)) == "video"
+    assert video_dataset.map_config_to_modality((64, 64, 32)) == "video"
+
+    # Test invalid configurations
+    with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
+        video_dataset.map_config_to_modality((256, 256, 0))
+
+    with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
+        video_dataset.map_config_to_modality((256, 256, -1))
+
+
+@pytest.mark.benchmark
+def test_generate_mm_item_video(video_dataset: RandomMultiModalDataset):
+    """Test generating multimodal items for video configurations."""
+    # Test video item generation
+    video_config = (64, 48, 8)  # height, width, num_frames
+    result = video_dataset.generate_mm_item(video_config)
+
+    # Check the result structure matches OpenAI API format
+    assert isinstance(result, dict)
+    assert result["type"] == "video_url"
+    assert "video_url" in result
+    assert "url" in result["video_url"]
+
+    # Check that the URL is a data URL with base64 encoded video
+    url = result["video_url"]["url"]
+    assert url.startswith("data:video/mp4;base64,")
+
+    # Decode and verify the video content
+    base64_data = url.split(",")[1]
+    video_bytes = base64.b64decode(base64_data)
+    assert len(video_bytes) > 0
+
+    # Verify the video can be decoded
+    with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
+        temp_path = temp_file.name
+        temp_file.write(video_bytes)
+
+    try:
+        cap = cv2.VideoCapture(temp_path)
+        assert cap.isOpened()
+
+        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+
+        assert frame_count == 8
+        assert frame_width == 48
+        assert frame_height == 64
+
+        cap.release()
+    finally:
+        if os.path.exists(temp_path):
+            os.unlink(temp_path)
+
+
+@pytest.mark.benchmark
+def test_generate_mm_item_image(video_dataset: RandomMultiModalDataset):
+    """Test generating multimodal items for image configurations."""
+    # Test image item generation
+    image_config = (64, 48, 1)  # height, width, num_frames=1
+    result = video_dataset.generate_mm_item(image_config)
+
+    # Check the result structure matches OpenAI API format
+    assert isinstance(result, dict)
+    assert result["type"] == "image_url"
+    assert "image_url" in result
+    assert "url" in result["image_url"]
+
+    # Check that the URL is a data URL with base64 encoded image
+    url = result["image_url"]["url"]
+    assert url.startswith("data:image/jpeg;base64,")
+
+
+@pytest.mark.benchmark
+def test_generate_mm_item_invalid_config(video_dataset: RandomMultiModalDataset):
+    """Test error handling for invalid configurations."""
+    with pytest.raises(ValueError, match="Invalid multimodal item configuration"):
+        video_dataset.generate_mm_item((256, 256, 0))
+
+
+@pytest.mark.benchmark
+def test_sample_with_video_buckets(
+    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
+):
+    """Test sampling with video bucket configurations."""
+    # Configure bucket with video probability > 0
+    bucket_config = {
+        (64, 64, 1): 0.3,  # Images
+        (64, 64, 8): 0.7,  # Videos
+    }
+
+    limit_mm_per_prompt = {"image": 5, "video": 3}
+
+    samples = video_dataset.sample(
+        tokenizer=hf_tokenizer,
+        num_requests=5,
+        base_items_per_request=2,
+        num_mm_items_range_ratio=0.0,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+        bucket_config=bucket_config,
+        input_len=20,
+        output_len=5,
+    )
+
+    assert len(samples) == 5
+
+    # Check that samples contain both images and videos
+    video_count = 0
+    image_count = 0
+
+    for sample in samples:
+        assert isinstance(sample, SampleRequest)
+        assert sample.multi_modal_data is not None
+        assert isinstance(sample.multi_modal_data, list)
+
+        mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
+        assert len(mm_data) == 2  # base_items_per_request
+
+        for item in mm_data:
+            if item["type"] == "video_url":
+                video_count += 1
+                # Verify video URL format
+                url = item["video_url"]["url"]
+                assert url.startswith("data:video/mp4;base64,")
+            elif item["type"] == "image_url":
+                image_count += 1
+                # Verify image URL format
+                url = item["image_url"]["url"]
+                assert url.startswith("data:image/jpeg;base64,")
+
+    # Should have some videos due to 0.7 probability
+    assert video_count > 0
+    assert image_count > 0
+
+
+@pytest.mark.benchmark
+def test_sample_video_only_buckets(
+    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
+):
+    """Test sampling with only video buckets."""
+    bucket_config = {
+        (64, 64, 8): 1.0,  # Only videos
+    }
+
+    limit_mm_per_prompt = {"image": 0, "video": 2}
+
+    samples = video_dataset.sample(
+        tokenizer=hf_tokenizer,
+        num_requests=3,
+        base_items_per_request=1,
+        num_mm_items_range_ratio=0.0,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+        bucket_config=bucket_config,
+        input_len=20,
+        output_len=5,
+    )
+
+    assert len(samples) == 3
+
+    for sample in samples:
+        assert isinstance(sample, SampleRequest)
+        assert sample.multi_modal_data is not None
+        assert isinstance(sample.multi_modal_data, list)
+
+        mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
+        assert len(mm_data) == 1
+
+        item = mm_data[0]
+        assert item["type"] == "video_url"
+        url = item["video_url"]["url"]
+        assert url.startswith("data:video/mp4;base64,")
+
+
+@pytest.mark.benchmark
+def test_sample_respects_video_limits(
+    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
+):
+    """Test that sampling respects video limits per prompt."""
+    bucket_config = {
+        (64, 64, 8): 1.0,  # Only videos
+    }
+
+    # Set very low video limit
+    limit_mm_per_prompt = {"image": 0, "video": 1}
+
+    samples = video_dataset.sample(
+        tokenizer=hf_tokenizer,
+        num_requests=3,
+        base_items_per_request=1,
+        num_mm_items_range_ratio=0.0,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+        bucket_config=bucket_config,
+        input_len=20,
+        output_len=5,
+    )
+
+    assert len(samples) == 3
+
+    for sample in samples:
+        mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
+        assert len(mm_data) <= 1  # Should respect video limit
+
+
+@pytest.mark.benchmark
+def test_sample_mixed_buckets_with_zero_probability(
+    video_dataset: RandomMultiModalDataset, hf_tokenizer: PreTrainedTokenizerBase
+):
+    """Test sampling with mixed buckets including zero probability entries."""
+    bucket_config = {
+        (64, 64, 1): 0.5,  # Images
+        (64, 64, 8): 0.5,  # Videos
+        (128, 128, 16): 0.0,  # Zero probability videos (should be ignored)
+    }
+
+    limit_mm_per_prompt = {"image": 2, "video": 2}
+
+    samples = video_dataset.sample(
+        tokenizer=hf_tokenizer,
+        num_requests=4,
+        base_items_per_request=2,
+        num_mm_items_range_ratio=0.0,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+        bucket_config=bucket_config,
+        input_len=20,
+        output_len=5,
+    )
+
+    assert len(samples) == 4
+
+    # Should only see 64x64 videos, not 128x128 videos
+    for sample in samples:
+        mm_data = cast(list[dict[str, Any]], sample.multi_modal_data)
+        for item in mm_data:
+            if item["type"] == "video_url":
+                # Decode video to verify dimensions
+                url = item["video_url"]["url"]
+                base64_data = url.split(",")[1]
+                video_bytes = base64.b64decode(base64_data)
+
+                with NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:  # noqa
+                    temp_path = temp_file.name
+                    temp_file.write(video_bytes)
+
+                try:
+                    cap = cv2.VideoCapture(temp_path)
+                    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                    cap.release()
+
+                    # Should be 64x64, not 128x128
+                    assert frame_width == 64
+                    assert frame_height == 64
+                finally:
+                    if os.path.exists(temp_path):
+                        os.unlink(temp_path)
+
+
+@pytest.mark.benchmark
+def test_sample_deterministic_with_videos(hf_tokenizer: PreTrainedTokenizerBase):
+    """Test that sampling with videos is deterministic with same seed."""
+    dataset1 = RandomMultiModalDataset(random_seed=123)
+    dataset2 = RandomMultiModalDataset(random_seed=123)
+
+    bucket_config = {
+        (64, 64, 1): 0.3,  # Images
+        (64, 64, 8): 0.7,  # Videos
+    }
+
+    limit_mm_per_prompt = {"image": 2, "video": 2}
+
+    samples1 = dataset1.sample(
+        tokenizer=hf_tokenizer,
+        num_requests=3,
+        base_items_per_request=1,
+        num_mm_items_range_ratio=0.0,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+        bucket_config=bucket_config,
+        input_len=20,
+        output_len=5,
+    )
+
+    samples2 = dataset2.sample(
+        tokenizer=hf_tokenizer,
+        num_requests=3,
+        base_items_per_request=1,
+        num_mm_items_range_ratio=0.0,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+        bucket_config=bucket_config,
+        input_len=20,
+        output_len=5,
+    )
+
+    assert len(samples1) == len(samples2)
+
+    # Compare multimodal data
+    for s1, s2 in zip(samples1, samples2):
+        assert s1.multi_modal_data == s2.multi_modal_data
+
+
+@pytest.mark.benchmark
+def test_sample_different_seeds_produce_different_videos(
+    hf_tokenizer: PreTrainedTokenizerBase,
+):
+    """Test that different seeds produce different video content."""
+    dataset1 = RandomMultiModalDataset(random_seed=123)
+    dataset2 = RandomMultiModalDataset(random_seed=456)
+
+    bucket_config = {
+        (64, 64, 8): 1.0,  # Only videos
+    }
+
+    limit_mm_per_prompt = {"image": 0, "video": 1}
+
+    samples1 = dataset1.sample(
+        tokenizer=hf_tokenizer,
+        num_requests=2,
+        base_items_per_request=1,
+        num_mm_items_range_ratio=0.0,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+        bucket_config=bucket_config,
+        input_len=20,
+        output_len=5,
+    )
+
+    samples2 = dataset2.sample(
+        tokenizer=hf_tokenizer,
+        num_requests=2,
+        base_items_per_request=1,
+        num_mm_items_range_ratio=0.0,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+        bucket_config=bucket_config,
+        input_len=20,
+        output_len=5,
+    )
+
+    # Video content should be different
+    for s1, s2 in zip(samples1, samples2):
+        mm_data1 = cast(list[dict[str, Any]], s1.multi_modal_data)
+        mm_data2 = cast(list[dict[str, Any]], s2.multi_modal_data)
+
+        assert len(mm_data1) == len(mm_data2) == 1
+
+        url1 = mm_data1[0]["video_url"]["url"]
+        url2 = mm_data2[0]["video_url"]["url"]
+
+        assert url1 != url2  # Different video content
vllm/benchmarks/datasets.py
@@ -27,8 +27,10 @@ from copy import deepcopy
 from dataclasses import dataclass
 from functools import cache
 from io import BytesIO
+from tempfile import NamedTemporaryFile
 from typing import Any, cast
 
+import cv2
 import numpy as np
 from PIL import Image
 from transformers import PreTrainedTokenizerBase
@@ -498,9 +500,13 @@ class RandomDataset(BenchmarkDataset):
             num_requests, range_ratio, input_len, output_len, tokenizer
         )
 
-        # Generate prefix once
-        prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
         vocab_size = tokenizer.vocab_size
+        prohibited_tokens = tokenizer.all_special_ids
+        all_tokens = np.arange(vocab_size)
+        allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens)))
+
+        # Generate prefix once
+        prefix_token_ids = self.get_prefix(allowed_tokens, prefix_len)
+
         requests = []
         token_mismatch_total = 0
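As a minimal sketch of the filtering step above, assuming a GPT-2 tokenizer (whose `all_special_ids` typically contains only the end-of-text id):

    import numpy as np
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    all_tokens = np.arange(tokenizer.vocab_size)
    prohibited_tokens = tokenizer.all_special_ids  # e.g. [50256] for GPT-2
    allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens)))
    # Prompts sampled from allowed_tokens avoid special ids, which would
    # otherwise be prone to mismatches in the decode/re-encode step.
    # This length check holds when all special ids fall inside the base vocab:
    assert len(allowed_tokens) == tokenizer.vocab_size - len(set(prohibited_tokens))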
@@ -513,6 +519,7 @@ class RandomDataset(BenchmarkDataset):
                 input_len=int(input_lens[i]),
                 offset=int(offsets[i]),
                 index=i,
+                allowed_tokens=allowed_tokens,
             )
             token_mismatch_total += token_mismatch
             requests.append(
@@ -553,13 +560,17 @@ class RandomDataset(BenchmarkDataset):
         return requests
 
     def get_prefix(
-        self, tokenizer: PreTrainedTokenizerBase, prefix_len: int
+        self,
+        allowed_tokens: np.ndarray,
+        prefix_len: int,
     ) -> list[int]:
         """
         Get the prefix for the dataset.
         """
         return (
-            self._rng.integers(0, tokenizer.vocab_size, size=prefix_len).tolist()
+            allowed_tokens[
+                self._rng.integers(0, len(allowed_tokens), size=prefix_len)
+            ].tolist()
             if prefix_len > 0
             else []
         )
@@ -623,6 +634,7 @@ class RandomDataset(BenchmarkDataset):
         input_len: int,
         offset: int,
         index: int,
+        allowed_tokens: np.ndarray,
     ) -> tuple[str, int, int]:
         """
         Returns (prompt, total_input_len).
@@ -636,8 +648,11 @@ class RandomDataset(BenchmarkDataset):
         To avoid uncontrolled change of the prompt length,
         the encoded sequence is truncated before being decoded again.
         """
-        # Build the inner sequence by sampling sequentially from the vocab
-        inner_seq = ((offset + index + np.arange(input_len)) % vocab_size).tolist()
+        # Build the inner sequence by sampling
+        # sequentially from the allowed tokens
+        inner_seq = allowed_tokens[
+            (offset + index + np.arange(input_len)) % len(allowed_tokens)
+        ].tolist()
         token_sequence = prefix_token_ids + inner_seq
 
         # Decode, then re-encode and truncate to preserve token count invariants
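A small worked illustration of the modular indexing above, using hypothetical values:

    import numpy as np

    allowed_tokens = np.array([5, 9, 12, 40])  # hypothetical allowed ids
    offset, index, input_len = 2, 1, 5
    positions = (offset + index + np.arange(input_len)) % len(allowed_tokens)
    # positions == [3, 0, 1, 2, 3]; the sequential walk wraps around
    inner_seq = allowed_tokens[positions].tolist()
    # inner_seq == [40, 5, 9, 12, 40]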
@@ -772,7 +787,7 @@ class RandomMultiModalDataset(RandomDataset):
 
     Status:
     - Images: supported via synthetic RGB data.
-    - Video: not yet supported (TODO: implement video generation method).
+    - Video: supported via synthetic RGB data.
     - Audio: not yet supported.
 
     Sampling overview:
@@ -782,7 +797,7 @@ class RandomMultiModalDataset(RandomDataset):
        The maximum is further clamped to the sum of per-modality limits.
     2) Each item’s modality and shape is sampled from `bucket_config`, a dict
        mapping (height, width, num_frames) → probability. We treat
-       `num_frames`=1 as image and and `num_frames` > 1 as video.
+       `num_frames`=1 as image and `num_frames` > 1 as video.
        Entries with zero probability are removed and the rest are renormalized
        to sum to 1.
     3) Per-modality hard caps are enforced via `limit_mm_per_prompt`.
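To make the sampling overview concrete, here is an illustrative sketch of the bucket interpretation and renormalization described in the docstring (it mirrors the description, not the class's internal code):

    bucket_config = {
        (64, 64, 1): 0.3,     # num_frames == 1 -> image
        (64, 64, 8): 0.6,     # num_frames > 1  -> video
        (128, 128, 16): 0.0,  # zero probability: dropped before sampling
    }
    nonzero = {cfg: p for cfg, p in bucket_config.items() if p > 0}
    total = sum(nonzero.values())
    probabilities = {cfg: p / total for cfg, p in nonzero.items()}
    # probabilities == {(64, 64, 1): 1/3, (64, 64, 8): 2/3}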
@@ -797,8 +812,7 @@ class RandomMultiModalDataset(RandomDataset):
     """
 
     IS_MULTIMODAL = True
-    # NOTE: video sampling is WIP. Setting it to 0.
-    DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 0}
+    DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 1}
 
     DEFAULT_BASE_ITEMS_PER_REQUEST = 1
     DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0
@@ -828,12 +842,47 @@ class RandomMultiModalDataset(RandomDataset):
         )
         return Image.fromarray(random_pixels)
 
-    def generate_synthetic_video(self, width: int, height: int, num_frames: int) -> Any:
+    def generate_synthetic_video(
+        self, width: int, height: int, num_frames: int
+    ) -> dict:
         """Generate synthetic video with random values.
 
-        TODO: Finish this method.
+        Creates a video with random pixel values, encodes it to MP4 format,
+        and returns the content as bytes.
         """
-        raise NotImplementedError("Video sampling is WIP.")
+        random_pixels = self._rng.integers(
+            0,
+            256,
+            (num_frames, height, width, 3),
+            dtype=np.uint8,
+        )
+
+        # Encode the frames to MP4 through a temporary file on disk
+        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+        fps = 30  # frames per second
+
+        with NamedTemporaryFile(suffix=".mp4", delete_on_close=False) as temp_file:
+            temp_path = temp_file.name
+
+            # Create video writer
+            video_writer = cv2.VideoWriter(
+                temp_path, fourcc=fourcc, fps=fps, frameSize=(width, height)
+            )
+
+            if not video_writer.isOpened():
+                raise RuntimeError("Failed to create video writer")
+
+            for frame in random_pixels:
+                video_writer.write(frame)
+
+            video_writer.release()
+            temp_file.close()
+
+            # Read the video file content
+            with open(temp_path, "rb") as f:
+                video_content = f.read()
+
+        return {"bytes": video_content}
+
     def map_config_to_modality(self, config: tuple[int, int, int]) -> str:
         """Map the configuration to the modality."""
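As a usage sketch (not part of the diff), the `{"bytes": ...}` payload returned above maps onto the OpenAI-style `video_url` items that the new tests assert against; something along these lines, assuming `ds` is a `RandomMultiModalDataset` instance:

    import base64

    video = ds.generate_synthetic_video(width=64, height=64, num_frames=8)
    b64 = base64.b64encode(video["bytes"]).decode("utf-8")
    item = {
        "type": "video_url",
        "video_url": {"url": f"data:video/mp4;base64,{b64}"},
    }
    # `item` now has the shape checked by test_generate_mm_item_video.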
@@ -1044,16 +1093,6 @@ class RandomMultiModalDataset(RandomDataset):
         enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT,
         **kwargs,
     ) -> list[SampleRequest]:
-        # NOTE: Video sampling is WIP. Raise error if video is in bucket config
-        # and probability is non-zero.
-        if any(
-            self.map_config_to_modality(cfg) == "video" and p > 0
-            for cfg, p in bucket_config.items()
-        ):
-            raise NotImplementedError(
-                "Video sampling not implemented; set its probability to 0."
-            )
-
         # Get the sampling parameters for the dataset
         input_lens, output_lens, offsets = self.get_sampling_params(
             num_requests, range_ratio, input_len, output_len, tokenizer
@@ -1071,9 +1110,24 @@ class RandomMultiModalDataset(RandomDataset):
             bucket_config,
         )
 
-        # Generate prefix once
-        prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
         vocab_size = tokenizer.vocab_size
+        # Can't use tokenizer.all_special_ids, since it returns
+        # ONLY the ids listed in special_tokens_map.json.
+        # We also want to exclude placeholder tokens and the
+        # tokens that mark the start/end of an image, as they
+        # may break prompt-replacement logic.
+        prohibited_tokens = list(
+            tok_id
+            for tok_id, token in tokenizer.added_tokens_decoder.items()
+            if token.special
+        )
+        all_tokens = np.arange(vocab_size)
+        allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens)))
+        logger.debug(
+            "Sampling from %d out of %d (vocab size)", len(allowed_tokens), vocab_size
+        )
+
+        # Generate prefix once
+        prefix_token_ids = self.get_prefix(allowed_tokens, prefix_len)
         # Add synthetic multimodal items to each request
         mm_requests = []
         token_mismatch_total = 0
@@ -1086,6 +1140,7 @@ class RandomMultiModalDataset(RandomDataset):
                 input_len=int(input_lens[i]),
                 offset=int(offsets[i]),
                 index=i,
+                allowed_tokens=allowed_tokens,
             )
             token_mismatch_total += token_mismatch
             # Get multimodal item iterator for a given request
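For context on the comment in the hunk above, a brief sketch of the difference between the two sources of special ids (illustrative; the exact ids depend on the tokenizer):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # Only the tokens registered as special in special_tokens_map.json:
    print(tokenizer.all_special_ids)
    # Every added token flagged as special, which on multimodal tokenizers
    # also covers image placeholder and image start/end tokens:
    special_added = [
        tok_id
        for tok_id, token in tokenizer.added_tokens_decoder.items()
        if token.special
    ]
    print(special_added)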