diff --git a/benchmarks/README.md b/benchmarks/README.md index 176b40212978f..a2dd5bb58325c 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -59,6 +59,12 @@ become available. ✅ synthetic + + RandomMultiModal (Image/Video) + 🟡 + 🚧 + synthetic + Prefix Repetition ✅ @@ -722,4 +728,75 @@ python benchmarks/benchmark_serving.py \ --endpoint /v1/chat/completion ``` +### Synthetic Random Images (random-mm) + +Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets. + +Notes: + +- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`. +- Video sampling is not yet implemented. + +Start the server (example): + +```bash +vllm serve Qwen/Qwen2.5-VL-3B-Instruct \ + --dtype bfloat16 \ + --max-model-len 16384 \ + --limit-mm-per-prompt '{"image": 3, "video": 0}' \ + --mm-processor-kwargs max_pixels=1003520 +``` + +Benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via the arg `random-output-len`. + +Ex.1: Fixed number of items and a single image resolutionm, enforcing generation of approx 40 tokens: + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-3B-Instruct \ + --endpoint /v1/chat/completions \ + --dataset-name random-mm \ + --num-prompts 100 \ + --max-concurrency 10 \ + --random-prefix-len 25 \ + --random-input-len 300 \ + --random-output-len 40 \ + --random-range-ratio 0.2 \ + --random-mm-base-items-per-request 2 \ + --random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \ + --random-mm-bucket-config '{(224, 224, 1): 1.0}' \ + --request-rate inf \ + --ignore-eos \ + --seed 42 +``` + +The number of items per request can be controlled by passing multiple image buckets: + +```bash + --random-mm-base-items-per-request 2 \ + --random-mm-num-mm-items-range-ratio 0.5 \ + --random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \ + --random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \ +``` + +Flags specific to `random-mm`: + +- `--random-mm-base-items-per-request`: base number of multimodal items per request. +- `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items. +- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'. +- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 for videos (video sampling not yet supported). + +Behavioral notes: + +- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping. + +How sampling works: + +- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits. +- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added. +- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing. +This should be seen as an edge case, and if this behavior can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number. Note that this might result in errors due to engine config `--limit-mm-per-prompt`. +- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`. + diff --git a/tests/benchmarks/test_random_dataset.py b/tests/benchmarks/test_random_dataset.py new file mode 100644 index 0000000000000..26cae369cdd5d --- /dev/null +++ b/tests/benchmarks/test_random_dataset.py @@ -0,0 +1,344 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random +from typing import Any, NamedTuple, Optional, cast + +import numpy as np +import pytest +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset, + SampleRequest) + + +@pytest.fixture(scope="session") +def hf_tokenizer() -> PreTrainedTokenizerBase: + # Use a small, commonly available tokenizer + return AutoTokenizer.from_pretrained("gpt2") + + +class Params(NamedTuple): + num_requests: int + prefix_len: int + range_ratio: float + input_len: int + output_len: int + + +@pytest.fixture(scope="session") +def random_dataset_params() -> Params: + return Params(num_requests=16, + prefix_len=7, + range_ratio=0.3, + input_len=50, + output_len=20) + + +def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]: + """Project a SampleRequest into a comparable tuple.""" + return (req.prompt, req.prompt_len, req.expected_output_len) + + +def _collect_samples(dataset: RandomDataset, + tokenizer: PreTrainedTokenizerBase, + num_requests: int = 16, + prefix_len: int = 7, + range_ratio: float = 0.3, + input_len: int = 50, + output_len: int = 20) -> list[tuple[str, int, int]]: + samples = dataset.sample( + tokenizer=tokenizer, + num_requests=num_requests, + prefix_len=prefix_len, + range_ratio=range_ratio, + input_len=input_len, + output_len=output_len, + ) + return [_fingerprint_sample(s) for s in samples] + + +@pytest.mark.benchmark +def test_random_dataset_same_seed( + hf_tokenizer: PreTrainedTokenizerBase, + random_dataset_params: Params) -> None: + """Same seed should yield identical outputs, even if global RNGs change. + + This guards against accidental reliance on Python's random or np.random + in RandomDataset after moving to numpy.default_rng. + """ + p = random_dataset_params + common_seed = 123 + dataset_a = RandomDataset(random_seed=common_seed) + dataset_b = RandomDataset(random_seed=common_seed) + a = _collect_samples(dataset_a, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + + # Perturb global RNG state to ensure isolation + random.seed(999) + _ = [random.random() for _ in range(100)] + np.random.seed(888) + _ = [np.random.random() for _ in range(100)] + + b = _collect_samples(dataset_b, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + assert a == b + +@pytest.mark.benchmark +def test_random_dataset_different_seeds( + hf_tokenizer: PreTrainedTokenizerBase, + random_dataset_params: Params) -> None: + """Different seeds should change outputs with overwhelming likelihood.""" + p = random_dataset_params + seed_a = 0 + dataset_a = RandomDataset(random_seed=seed_a) + a = _collect_samples(dataset_a, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + + seed_b = 999 + dataset_b = RandomDataset(random_seed=seed_b) + # Perturb global RNG with same seed as dataset_a to ensure isolation + random.seed(seed_a) + np.random.seed(seed_a) + b = _collect_samples(dataset_b, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + assert a != b + + +# ----------------------------- +# RandomMultiModalDataset tests +# ----------------------------- + +def _mm_fingerprint_sample( + req: SampleRequest, +) -> tuple[str, int, int, int, list[str]]: + """Create a compact fingerprint for multimodal samples. + + Includes: + - prompt string + - prompt_len + - expected_output_len + - count of multimodal items + - per-item type and URL prefix (e.g., 'data:image/jpeg;base64,') + """ + items = req.multi_modal_data or [] + item_prefixes: list[str] = [] + for it in items: + if isinstance(it, dict) and it.get("type") == "image_url": + url = it.get("image_url", {}).get("url", "") + # Only keep a short identifying prefix to avoid huge strings + item_prefixes.append(f"image:{url[:22]}") + elif isinstance(it, dict) and it.get("type") == "video_url": + url = it.get("video_url", {}).get("url", "") + item_prefixes.append(f"video:{url[:22]}") + else: + item_prefixes.append("unknown:") + return (req.prompt, req.prompt_len, req.expected_output_len, len(items), + item_prefixes) + + +def _collect_mm_samples( + dataset: RandomMultiModalDataset, + tokenizer: PreTrainedTokenizerBase, + *, + num_requests: int = 8, + prefix_len: int = 3, + range_ratio: float = 0.0, + input_len: int = 20, + output_len: int = 5, + base_items_per_request: int = 2, + num_mm_items_range_ratio: float = 0.0, + limit_mm_per_prompt: Optional[dict[str, int]] = None, + bucket_config: Optional[dict[tuple[int, int, int], float]] = None, + enable_multimodal_chat: bool = False, +) -> list[SampleRequest]: + if limit_mm_per_prompt is None: + limit_mm_per_prompt = {"image": 5, "video": 0} + if bucket_config is None: + bucket_config = {(32, 32, 1): 0.5, (52, 64, 1): 0.5} + return dataset.sample( + tokenizer=tokenizer, + num_requests=num_requests, + prefix_len=prefix_len, + range_ratio=range_ratio, + input_len=input_len, + output_len=output_len, + base_items_per_request=base_items_per_request, + num_mm_items_range_ratio=num_mm_items_range_ratio, + limit_mm_per_prompt=limit_mm_per_prompt, + bucket_config=bucket_config, + enable_multimodal_chat=enable_multimodal_chat, + ) + + +@pytest.mark.benchmark +def test_random_mm_same_seed(hf_tokenizer: PreTrainedTokenizerBase) -> None: + seed = 42 + ds_a = RandomMultiModalDataset(random_seed=seed) + ds_b = RandomMultiModalDataset(random_seed=seed) + a = _collect_mm_samples(ds_a, hf_tokenizer) + b = _collect_mm_samples(ds_b, hf_tokenizer) + fa = [_mm_fingerprint_sample(s) for s in a] + fb = [_mm_fingerprint_sample(s) for s in b] + assert fa == fb + + +@pytest.mark.benchmark +def test_random_mm_different_seeds( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds_a = RandomMultiModalDataset(random_seed=0) + ds_b = RandomMultiModalDataset(random_seed=999) + a = _collect_mm_samples(ds_a, hf_tokenizer) + b = _collect_mm_samples(ds_b, hf_tokenizer) + fa = [_mm_fingerprint_sample(s) for s in a] + fb = [_mm_fingerprint_sample(s) for s in b] + assert fa != fb + +@pytest.mark.benchmark +def test_random_mm_respects_limits( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Requesting 3 items with a per-prompt limit of 1 should error per current + # design (dataset refuses to silently clamp below the requested baseline). + with pytest.raises(ValueError): + _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=12, + base_items_per_request=3, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 1, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + + +@pytest.mark.benchmark +def test_random_mm_zero_prob_entries_are_removed( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Second bucket has zero probability and should be ignored after + # normalization + samples = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=6, + base_items_per_request=2, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 10, "video": 0}, + bucket_config={(32, 32, 1): 1.0, (52, 64, 1): 0.0}, + ) + for s in samples: + assert isinstance(s.multi_modal_data, list) + typed_mm = cast(list[dict[str, Any]], s.multi_modal_data) + for it in typed_mm: + assert it.get("type") == "image_url" + + +@pytest.mark.benchmark +def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None: + ds = RandomMultiModalDataset(random_seed=0) + samples = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=0, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 5, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + for s in samples: + assert s.multi_modal_data == [] + +@pytest.mark.benchmark +def test_random_mm_num_items_per_prompt( + hf_tokenizer: PreTrainedTokenizerBase) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Fixed number of images per prompt + # set num_mm_items_range_ratio to 0.0 + # TODO: modify video values when video sampling is implemented + samples_fixed_items = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=3, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 3, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + # Must have 5 requests each with 3 mm items per prompt + assert len(samples_fixed_items) == 5 + for s in samples_fixed_items: + mm_data = cast(list[dict[str, Any]], s.multi_modal_data) + assert len(mm_data) == 3 + for it in mm_data: + assert it.get("type") == "image_url" + + +@pytest.mark.benchmark +def test_random_mm_bucket_config_not_mutated( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + + ds = RandomMultiModalDataset(random_seed=0) + # This bucket config is not normalized to sum to 1 + # and has more buckets than requested images + original = {(32, 32, 1): 0.2, (52, 64, 1): 6, (25, 64, 1): 3} + # Keep a snapshot to compare after sampling + snapshot = dict(original) + + _ = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=4, + base_items_per_request=1, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 1, "video": 0}, + bucket_config=original, + ) + + # Ensure the original dict content is unchanged + assert original == snapshot + + + # Vary number of mm items per prompt + # set num_mm_items_range_ratio to 0.5 + samples_varying_items = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=2, + num_mm_items_range_ratio=0.5, + limit_mm_per_prompt={"image": 4, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + # Must have 5 requests each with less than 4 mm items per prompt + # but at least 1 mm item per prompt + assert len(samples_varying_items) == 5 + for s in samples_varying_items: + mm_data = cast(list[dict[str, Any]], s.multi_modal_data) + assert len(mm_data) <= 4 + assert len(mm_data) >= 1 + for it in mm_data: + assert it.get("type") == "image_url" diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 920d21bda3c5b..e586337367b1c 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -11,18 +11,21 @@ generation. Supported dataset types include: - HuggingFace - VisionArena """ +import ast import base64 import io import json import logging +import math import random from abc import ABC, abstractmethod -from collections.abc import Mapping +from collections.abc import Iterator, Mapping +from contextlib import suppress from copy import deepcopy from dataclasses import dataclass from functools import cache from io import BytesIO -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, cast import numpy as np from PIL import Image @@ -114,7 +117,9 @@ class BenchmarkDataset(ABC): def apply_multimodal_chat_transformation( self, prompt: str, - mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + mm_content: Optional[ + Union[MultiModalDataDict, dict, list[dict]] + ] = None) -> list[dict]: """ Transform a prompt and optional multimodal content into a chat format. This method is used for chat models that expect a specific conversation @@ -122,7 +127,15 @@ class BenchmarkDataset(ABC): """ content = [{"text": prompt, "type": "text"}] if mm_content is not None: - content.append(mm_content) + if isinstance(mm_content, list): + content.extend(cast(list[dict[str, Any]], mm_content)) + elif isinstance(mm_content, dict): + content.append(mm_content) + else: + raise TypeError( + "Could not process multimodal content of type: " + + f"{type(mm_content)}" + ) return [{"role": "user", "content": content}] def load_data(self) -> None: @@ -362,90 +375,536 @@ def process_video(video: Any) -> Mapping[str, Any]: class RandomDataset(BenchmarkDataset): + """ + Synthetic text-only dataset for serving/throughput benchmarks. + + Strategy: + - Sample input/output token lengths per request from integer-uniform ranges + around configured means (controlled by range_ratio). + - Prepend a fixed random prefix of length prefix_len. + - Generate the remaining tokens as a reproducible sequence: + (offset + index + arange(input_len)) % vocab_size. + - Decode then re-encode/truncate to ensure prompt token counts match. + - Uses numpy.default_rng seeded with random_seed for reproducible sampling. + """ # Default values copied from benchmark_serving.py for the random dataset. DEFAULT_PREFIX_LEN = 0 DEFAULT_RANGE_RATIO = 0.0 DEFAULT_INPUT_LEN = 1024 DEFAULT_OUTPUT_LEN = 128 - def __init__( - self, - **kwargs, - ) -> None: + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - random.seed(self.random_seed) - np.random.seed(self.random_seed) + # Use numpy's default_rng for deterministic sampling + # Do not use random.seed() or np.random.seed() elsewhere in this class. + # This ensures that the RNG is isolated from global RNG state. + self._rng = np.random.default_rng(self.random_seed) def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, + request_id_prefix: str = "", prefix_len: int = DEFAULT_PREFIX_LEN, range_ratio: float = DEFAULT_RANGE_RATIO, input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, - request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: - # Enforce range_ratio < 1 - assert range_ratio < 1.0, ( - "random_range_ratio must be < 1.0 to ensure a valid sampling range" + + input_lens, output_lens, offsets = self.get_sampling_params( + num_requests, range_ratio, input_len, output_len, tokenizer ) + # Generate prefix once + prefix_token_ids = self.get_prefix(tokenizer, prefix_len) vocab_size = tokenizer.vocab_size - num_special_tokens = tokenizer.num_special_tokens_to_add() - real_input_len = input_len - num_special_tokens - - prefix_token_ids = (np.random.randint( - 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) - - # New sampling logic: [X * (1 - b), X * (1 + b)] - input_low = int(real_input_len * (1 - range_ratio)) - input_high = int(real_input_len * (1 + range_ratio)) - output_low = int(output_len * (1 - range_ratio)) - output_high = int(output_len * (1 + range_ratio)) - - # Add logging for debugging - logger.info( - "Sampling input_len from [%s, %s] and output_len from [%s, %s]", - input_low, input_high, output_low, output_high) - - input_lens = np.random.randint(input_low, - input_high + 1, - size=num_requests) - output_lens = np.random.randint(output_low, - output_high + 1, - size=num_requests) - offsets = np.random.randint(0, vocab_size, size=num_requests) requests = [] for i in range(num_requests): - inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % - vocab_size).tolist() - token_sequence = prefix_token_ids + inner_seq - prompt = tokenizer.decode(token_sequence) - # After decoding the prompt we have to encode and decode it again. - # This is done because in some cases N consecutive tokens - # give a string tokenized into != N number of tokens. - # For example for GPT2Tokenizer: - # [6880, 6881] -> ['Ġcalls', 'here'] -> - # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] - # To avoid uncontrolled change of the prompt length, - # the encoded sequence is truncated before being decode again. - total_input_len = prefix_len + int(input_lens[i]) - re_encoded_sequence = tokenizer.encode( - prompt, add_special_tokens=False)[:total_input_len] - prompt = tokenizer.decode(re_encoded_sequence) - total_input_len = len(re_encoded_sequence) + prompt, total_input_len = self.generate_token_sequence( + tokenizer=tokenizer, + prefix_token_ids=prefix_token_ids, + prefix_len=prefix_len, + vocab_size=vocab_size, + input_len=int(input_lens[i]), + offset=int(offsets[i]), + index=i, + ) requests.append( SampleRequest( prompt=prompt, prompt_len=total_input_len, expected_output_len=int(output_lens[i]), request_id=request_id_prefix + str(i), - )) + ) + ) return requests + def get_prefix( + self, tokenizer: PreTrainedTokenizerBase, prefix_len: int + ) -> list[int]: + """ + Get the prefix for the dataset. + """ + return ( + self._rng.integers( + 0, tokenizer.vocab_size, size=prefix_len).tolist() + if prefix_len > 0 + else [] + ) + + def get_sampling_params( + self, + num_requests: int, + range_ratio: float, + input_len: int, + output_len: int, + tokenizer: PreTrainedTokenizerBase, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Get the sampling parameters for the dataset. + """ + # Enforce range_ratio < 1 + if not (0.0 <= range_ratio < 1.0): + raise ValueError("range_ratio must be in [0, 1).") + num_special_tokens = int(tokenizer.num_special_tokens_to_add()) + real_input_len = max(0, int(input_len) - num_special_tokens) + # Bounds use floor for low and ceil for high + input_low = math.floor(real_input_len * (1 - range_ratio)) + input_high = math.ceil(real_input_len * (1 + range_ratio)) + output_low = math.floor(output_len * (1 - range_ratio)) + output_high = math.ceil(output_len * (1 + range_ratio)) + # Ensure the lower bound for output length is at least 1 to + # prevent sampling 0 tokens. + output_low = max(output_low, 1) + + if input_low > input_high: + raise ValueError( + "Invalid input sampling interval: " + f"low={input_low} > high={input_high}" + ) + if output_low > output_high: + raise ValueError( + "Invalid output sampling interval: " + f"low={output_low} > high={output_high}" + ) + + logger.info( + "Sampling input_len from [%s, %s] and output_len from [%s, %s]", + input_low, + input_high, + output_low, + output_high, + ) + + input_lens = self._rng.integers(input_low, input_high + 1, + size=num_requests) + output_lens = self._rng.integers(output_low, output_high + 1, + size=num_requests) + offsets = self._rng.integers(0, tokenizer.vocab_size, + size=num_requests) + return input_lens, output_lens, offsets + + + def generate_token_sequence( + self, + *, + tokenizer: PreTrainedTokenizerBase, + prefix_token_ids: list[int], + prefix_len: int, + vocab_size: int, + input_len: int, + offset: int, + index: int, + ) -> tuple[str, int]: + """ + Returns (prompt, total_input_len). + + NOTE: After decoding the prompt we have to encode and decode it again. + This is done because in some cases N consecutive tokens + give a string tokenized into != N number of tokens. + For example for GPT2Tokenizer: + [6880, 6881] -> ['Ġcalls', 'here'] -> + [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] + To avoid uncontrolled change of the prompt length, + the encoded sequence is truncated before being decode again. + """ + # Build the inner sequence by sampling sequentially from the vocab + inner_seq = ((offset + index + np.arange(input_len)) + % vocab_size).tolist() + token_sequence = prefix_token_ids + inner_seq + + # Decode, then re-encode and truncate to preserve token count invariants + prompt = tokenizer.decode(token_sequence) + total_input_len = prefix_len + int(input_len) + + re_encoded_sequence = tokenizer.encode( + prompt, add_special_tokens=False)[:total_input_len] + prompt = tokenizer.decode(re_encoded_sequence) + total_input_len = len(re_encoded_sequence) + + return prompt, total_input_len + + +# ----------------------------------------------------------------------------- +# MultiModalDataset Implementation +# ----------------------------------------------------------------------------- + +class RandomMultiModalDataset(RandomDataset): + """ + Synthetic multimodal dataset (text + images) that extends RandomDataset. + + Status: + - Images: supported via synthetic RGB data. + - Video: not yet supported (TODO: implement video generation method). + - Audio: not yet supported. + + Sampling overview: + 1) Number of items per request is sampled uniformly from the integer range + [floor(n·(1−r)), ceil(n·(1+r))], where n is the base count and r is + `num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0. + The maximum is further clamped to the sum of per-modality limits. + 2) Each item’s modality and shape is sampled from `bucket_config`, a dict + mapping (height, width, num_frames) → probability. We treat + `num_frames`=1 as image and and `num_frames` > 1 as video. + Entries with zero probability are removed and the rest are renormalized + to sum to 1. + 3) Per-modality hard caps are enforced via `limit_mm_per_prompt`. + When a modality reaches its cap, all of its buckets are excluded and the + remaining probabilities are renormalized. + + Example bucket configuration: + {(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1} + - Two image buckets (`num_frames`=1) and one video bucket + (`num_frames`=16). + OBS.: Only image sampling is supported for now. + """ + + IS_MULTIMODAL = True + # NOTE: video sampling is WIP. Setting it to 0. + DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 0} + + DEFAULT_BASE_ITEMS_PER_REQUEST = 1 + DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0 + DEFAULT_MM_ITEM_BUCKET_CONFIG = { + (256, 256, 1): 0.5, + (720, 1280, 1): 0.5, + (720, 1280, 16): 0.0, + } + DEFAULT_ENABLE_MULTIMODAL_CHAT = False + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + + def generate_synthetic_image(self, width: int, height: int) -> Image.Image: + """Generate synthetic PIL image with random RGB values. + + NOTE: iid pixel sampling results in worst-case compression + (good for stressing I/O), but very unlike real photos. + We could consider a “low-freq” mode (e.g., noise blur) + to emulate network realism instead of max stress. + """ + random_pixels = self._rng.integers( + 0, + 256, + (height, width, 3), + dtype=np.uint8, + ) + return Image.fromarray(random_pixels) + + def generate_synthetic_video(self, width: int, + height: int, + num_frames: int) -> Any: + """Generate synthetic video with random values. + + TODO: Finish this method. + """ + raise NotImplementedError("Video sampling is WIP.") + + def map_config_to_modality(self, config: tuple[int, int, int]) -> str: + """Map the configuration to the modality.""" + if config[-1] == 1: + return "image" + elif config[-1] > 1: + return "video" + else: + raise ValueError(f"Invalid multimodal item configuration: {config}") + + def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int], + float]) -> dict[tuple[int, int, int], float]: + """ + Remove zero probability entries + and normalize the bucket config to sum to 1. + """ + # Raise error if value is negative + if any(v < 0 for v in bucket_config.values()): + raise ValueError("Bucket config values must be non-negative.") + # Remove zero probability entries + bucket_config = {k: v for k, v in bucket_config.items() if v > 0} + # if bucket config is empty, raise error + if not bucket_config: + raise ValueError("Got invalid bucket config. " + "Bucket config values must be non-zero.") + # Normalize the remaining bucket config to sum to 1 + total = sum(bucket_config.values()) + return {k: v / total for k, v in bucket_config.items()} + + + def generate_mm_item(self, + mm_item_config: tuple[int, int, int], + ) -> Mapping[str, Any]: + """ + Create synthetic images and videos and + apply process_image/process_video respectively. + This follows the OpenAI API chat completions + https://github.com/openai/openai-python + """ + + if self.map_config_to_modality(mm_item_config) == "image": + return process_image(self.generate_synthetic_image( + mm_item_config[1], + mm_item_config[0])) + elif self.map_config_to_modality(mm_item_config) == "video": + return process_video(self.generate_synthetic_video( + mm_item_config[1], + mm_item_config[0], + mm_item_config[2])) + else: + raise ValueError(f"Invalid multimodal item configuration: " + f"{mm_item_config}") + + + def get_mm_item_sampling_params( + self, + base_items_per_request: int, + num_mm_items_range_ratio: float, + limit_mm_per_prompt: dict[str, int], + bucket_config: dict[tuple[int, int, int], float], + ) -> tuple[int, int, dict[str, int], dict[tuple[int, int, int], float]]: + """ + Get the sampling parameters for the multimodal items. + """ + # Enforce num_mm_items_range_ratio <= 1 + if not (0.0 <= num_mm_items_range_ratio <= 1.0): + raise ValueError("num_mm_items_range_ratio must be in [0, 1].") + + # Ensure modalities to sample are in limit_mm_per_prompt + for k, v in bucket_config.items(): + # get modality from bucket config + modality = self.map_config_to_modality(k) + if modality not in limit_mm_per_prompt: + raise ValueError(f"Modality {modality} is not in " + f"limit_mm_per_prompt: " + f"{limit_mm_per_prompt.keys()}") + + # Remove zero probability entries + # and normalize bucket config to sum to 1 + bucket_config = self.normalize_bucket_config(bucket_config) + logger.info( + "Normalized bucket config: %s", bucket_config, + ) + # Only consider limit per prompt for modalities in bucket config + allowed_modalities = {self.map_config_to_modality(cfg) + for cfg in bucket_config} + limit_mm_per_prompt = { + k: v for k, v in limit_mm_per_prompt.items() + if k in allowed_modalities} + if not limit_mm_per_prompt: + raise ValueError("No valid limits for modalities present in " + "bucket_config.") + + logger.info( + "Updated mm-limit-per-prompt: %s", limit_mm_per_prompt, + ) + + # Get max and min num mm items and ensure + # it is at most the sum of limit_mm_per_prompt for all modalities + max_num_mm_items = min( + sum(limit_mm_per_prompt.values()), + math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio)) + ) + # Ensure min num mm items is at least 0 + min_num_mm_items = max( + 0, + math.floor(base_items_per_request * (1 - num_mm_items_range_ratio)) + ) + # Raise error if min num mm items is greater than max num mm items + if min_num_mm_items > max_num_mm_items: + raise ValueError(f"Min num mm items is greater than max mm items: " + f"{min_num_mm_items} > {max_num_mm_items}") + + logger.info( + "Sampling number of multimodal items from [%s, %s]", + min_num_mm_items, max_num_mm_items, + ) + + return ( + min_num_mm_items, + max_num_mm_items, + limit_mm_per_prompt, + bucket_config, + ) + + def get_mm_item_iterator( + self, + min_num_mm_items: int, + max_num_mm_items: int, + bucket_config: dict[tuple[int, int, int], float], + limit_mm_per_prompt: dict[str, int], + ) -> Iterator[tuple[int,int, int]]: + """ + Iterator over the multimodal items for each request + whose size is between min_num_mm_items and max_num_mm_items. + + Loop over the bucket config and sample a multimodal item. + Loop until the number of multimodal items sampled is equal to + request_num_mm_items or limit of multimodal items per prompt + for all modalities is reached. + + Note: + - This function operates on a per-request shallow copy of + `bucket_config` (tuple->float). The original dict passed to + `sample` is not mutated. If this ever changes, a test + is implemented and will fail. + """ + # Get the number of multimodal items to sample + request_num_mm_items = int( + self._rng.integers(min_num_mm_items, max_num_mm_items + 1) + ) + # If request_num_mm_items is 0, yield an empty iterator + if request_num_mm_items == 0: + return + # Initialize modality counters + modality_counter = {self.map_config_to_modality(k): 0 + for k in bucket_config} + # Copy the bucket config to avoid modifying the original + bucket_config_copy = bucket_config.copy() + # Loop over the number of multimodal items to sample + while sum(modality_counter.values()) < request_num_mm_items: + # Sample a multimodal item config + mm_item_config = self._rng.choice(list(bucket_config_copy.keys()), + p=list(bucket_config_copy.values())) + modality = self.map_config_to_modality(mm_item_config) + # Check that modality count is less than limit per prompt + if modality_counter[modality] < limit_mm_per_prompt[modality]: + modality_counter[modality] += 1 + yield ( + mm_item_config + ) + else: + # If the counter is greater than the limit per prompt + # set all multimodal items of this modality to 0 + for k, v in bucket_config_copy.items(): + if self.map_config_to_modality(k) == modality: + bucket_config_copy[k] = 0 + # If all configs are 0, break the loop + # This should not happen as request_num_mm_items is at most + # the sum of limit_mm_per_prompt for all modalities + if all(v == 0 for v in bucket_config_copy.values()): + logger.warning("Exhausted all multimodal items " + "of modality %s", + modality) + break + # Renormalize the bucket config + bucket_config_copy = self.normalize_bucket_config( + bucket_config_copy) + + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN, + range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO, + input_len: int = RandomDataset.DEFAULT_INPUT_LEN, + output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN, + limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT, + base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST, + num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, + bucket_config: dict[tuple[int, int, int], float] = + DEFAULT_MM_ITEM_BUCKET_CONFIG, + enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT, + **kwargs, + ) -> list[SampleRequest]: + + # NOTE: Video sampling is WIP. Raise error if video is in bucket config + # and probability is non-zero. + if any(self.map_config_to_modality(cfg) == "video" and p > 0 + for cfg, p in bucket_config.items()): + raise NotImplementedError("Video sampling not implemented; " + "set its probability to 0.") + + # Get the sampling parameters for the dataset + input_lens, output_lens, offsets = self.get_sampling_params( + num_requests, range_ratio, input_len, output_len, tokenizer + ) + + ( + min_num_mm_items, + max_num_mm_items, + limit_mm_per_prompt, + bucket_config, + ) = self.get_mm_item_sampling_params( + base_items_per_request, + num_mm_items_range_ratio, + limit_mm_per_prompt, + bucket_config, + ) + + # Generate prefix once + prefix_token_ids = self.get_prefix(tokenizer, prefix_len) + vocab_size = tokenizer.vocab_size + # Add synthetic multimodal items to each request + mm_requests = [] + for i in range(num_requests): + prompt, total_input_len = self.generate_token_sequence( + tokenizer=tokenizer, + prefix_token_ids=prefix_token_ids, + prefix_len=prefix_len, + vocab_size=vocab_size, + input_len=int(input_lens[i]), + offset=int(offsets[i]), + index=i, + ) + # Get multimodal item iterator for a given request + mm_item_iterator = self.get_mm_item_iterator( + min_num_mm_items, + max_num_mm_items, + bucket_config, + limit_mm_per_prompt, + ) + + mm_content = cast(list[dict[str, Any]], [ + self.generate_mm_item(mm_item_config) + for mm_item_config in mm_item_iterator + ]) + + if enable_multimodal_chat: + # NOTE: For now this option is only provided for completeness + # given that the serve.py benchmark currently does not use it. + mm_chat_prompt: Any = prompt + mm_chat_prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sample_request = SampleRequest( + prompt=mm_chat_prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + multi_modal_data=None, + request_id=request_id_prefix + str(i), + ) + else: + sample_request = SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + multi_modal_data=mm_content, + request_id=request_id_prefix + str(i), + ) + mm_requests.append(sample_request) + return mm_requests # ----------------------------------------------------------------------------- # ShareGPT Dataset Implementation @@ -545,8 +1004,8 @@ def add_dataset_parser(parser: FlexibleArgumentParser): type=str, default="random", choices=[ - "sharegpt", "burstgpt", "sonnet", "random", "hf", "custom", - "prefix_repetition" + "sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf", + "custom", "prefix_repetition" ], help="Name of the dataset to benchmark on.", ) @@ -647,6 +1106,98 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "input_len * (1 + range_ratio)]."), ) + # random multimodal dataset options + random_mm_group = parser.add_argument_group( + "random multimodal dataset options extended from random dataset") + random_mm_group.add_argument( + "--random-mm-base-items-per-request", + type=int, + default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST, + help=( + "Base number of multimodal items per request for random-mm. " + "Actual per-request count is sampled around this base using " + "--random-mm-num-mm-items-range-ratio." + ), + ) + random_mm_group.add_argument( + "--random-mm-num-mm-items-range-ratio", + type=float, + default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, + help=( + "Range ratio r in [0, 1] for sampling items per request. " + "We sample uniformly from the closed integer range " + "[floor(n*(1-r)), ceil(n*(1+r))] " + "where n is the base items per request. " + "r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped " + "to the sum of per-modality limits from " + "--random-mm-limit-mm-per-prompt. " + "An error is raised if the computed min exceeds the max." + ), + ) + random_mm_group.add_argument( + "--random-mm-limit-mm-per-prompt", + type=json.loads, + default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT, + help=( + "Per-modality hard caps for items attached per request, e.g. " + "'{\"image\": 3, \"video\": 0}'. The sampled per-request item " + "count is clamped to the sum of these limits. When a modality " + "reaches its cap, its buckets are excluded and probabilities are " + "renormalized." + "OBS.: Only image sampling is supported for now." + ), + ) + + def _parse_mm_bucket_config(v: object) -> dict[tuple[int, int, int], float]: + # If already a dict (e.g., programmatic call), normalize keys + def normalize(d: dict) -> dict[tuple[int, int, int], float]: + out: dict[tuple[int, int, int], float] = {} + for k, val in d.items(): + key = k + if isinstance(key, str): + with suppress(Exception): + key = ast.literal_eval(key) + if not (isinstance(key, tuple) and len(key) == 3 + and all(isinstance(x, int) for x in key)): + raise ValueError( + f"Invalid bucket key {k!r}. Expected tuple (H, W, T)." + ) + out[(int(key[0]), int(key[1]), int(key[2]))] = float(val) + return out + + if isinstance(v, dict): + return normalize(v) + if isinstance(v, str): + # Python literal (supports tuple keys) + parsed = ast.literal_eval(v) + if not isinstance(parsed, dict): + raise ValueError("Bucket config must parse to a dict.") + return normalize(parsed) + raise ValueError("Unsupported value for --random-mm-bucket-config.") + + random_mm_group.add_argument( + "--random-mm-bucket-config", + type=_parse_mm_bucket_config, + default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG, + help=( + "The bucket config is a dictionary mapping a multimodal item" + "sampling configuration to a probability." + "Currently allows for 2 modalities: images and videos. " + "An bucket key is a tuple of (height, width, num_frames)" + "The value is the probability of sampling that specific item. " + "Example: " + "--random-mm-bucket-config " + "{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10} " + "First item: images with resolution 256x256 w.p. 0.5" + "Second item: images with resolution 720x1280 w.p. 0.4 " + "Third item: videos with resolution 720x1280 and 16 frames w.p. 0.1" + "OBS.: If the probabilities do not sum to 1, they are normalized." + "OBS bis.: Only image sampling is supported for now." + ), + ) + + + hf_group = parser.add_argument_group("hf dataset options") hf_group.add_argument("--hf-subset", type=str, @@ -821,6 +1372,22 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: range_ratio=args.random_range_ratio, request_id_prefix=args.request_id_prefix, ), + "random-mm": + lambda: RandomMultiModalDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + range_ratio=args.random_range_ratio, + input_len=args.random_input_len, + output_len=args.random_output_len, + base_items_per_request=args.random_mm_base_items_per_request, + limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt, + num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio, + bucket_config=args.random_mm_bucket_config, + request_id_prefix=args.request_id_prefix, + ), "prefix_repetition": lambda: PrefixRepetitionRandomDataset( random_seed=args.seed, dataset_path=args.dataset_path @@ -836,6 +1403,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: } try: + # Enforce endpoint compatibility for multimodal datasets. + if args.dataset_name == "random-mm" and args.endpoint_type not in [ + "openai-chat"]: + raise ValueError( + "Multi-modal content (images) is only supported on " + "'openai-chat' backend." + ) input_requests = dataset_mapping[args.dataset_name]() except KeyError as err: raise ValueError(f"Unknown dataset: {args.dataset_name}") from err