diff --git a/benchmarks/README.md b/benchmarks/README.md
index 176b40212978f..a2dd5bb58325c 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -59,6 +59,12 @@ become available.
✅ | synthetic |
+| RandomMultiModal (Image/Video) | 🟡 | 🚧 | synthetic |
| Prefix Repetition | ✅ |
@@ -722,4 +728,75 @@ python benchmarks/benchmark_serving.py \
--endpoint /v1/chat/completion
```
+### Synthetic Random Images (random-mm)
+
+Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets.
+
+Notes:
+
+- Works only with the online serving benchmark using the OpenAI Chat backend (`--backend openai-chat`) and the `/v1/chat/completions` endpoint.
+- Video sampling is not yet implemented.
+
+Start the server (example):
+
+```bash
+vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
+ --dtype bfloat16 \
+ --max-model-len 16384 \
+ --limit-mm-per-prompt '{"image": 3, "video": 0}' \
+ --mm-processor-kwargs max_pixels=1003520
+```
+
+Run the benchmark. It is recommended to use the `--ignore-eos` flag so the model generates the full requested output, which better simulates real response lengths. The output size is controlled via `--random-output-len`.
+
+Example 1: fixed number of items and a single image resolution, enforcing generation of approximately 40 output tokens:
+
+```bash
+vllm bench serve \
+ --backend openai-chat \
+ --model Qwen/Qwen2.5-VL-3B-Instruct \
+ --endpoint /v1/chat/completions \
+ --dataset-name random-mm \
+ --num-prompts 100 \
+ --max-concurrency 10 \
+ --random-prefix-len 25 \
+ --random-input-len 300 \
+ --random-output-len 40 \
+ --random-range-ratio 0.2 \
+ --random-mm-base-items-per-request 2 \
+ --random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \
+ --random-mm-bucket-config '{(224, 224, 1): 1.0}' \
+ --request-rate inf \
+ --ignore-eos \
+ --seed 42
+```
+
+The number of items per request can be varied, and multiple image buckets with different resolutions can be mixed:
+
+```bash
+ --random-mm-base-items-per-request 2 \
+ --random-mm-num-mm-items-range-ratio 0.5 \
+ --random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \
+ --random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \
+```
+
+Flags specific to `random-mm`:
+
+- `--random-mm-base-items-per-request`: base number of multimodal items per request.
+- `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items.
+- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'.
+- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability, where T is the number of frames. Entries with probability 0 are removed and the remaining probabilities are renormalized to sum to 1. Use T=1 for images; any T>1 denotes a video (video sampling is not yet supported). A worked example follows this list.
+
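+For reference, a small worked example of how the item-count range and the bucket normalization behave (values are illustrative only):
+
+```python
+import math
+
+# Item count range for base n = 2 and ratio r = 0.5:
+n, r = 2, 0.5
+low, high = math.floor(n * (1 - r)), math.ceil(n * (1 + r))
+print(low, high)  # 1 3 -> each request gets 1, 2, or 3 items
+
+# Bucket probabilities need not sum to 1; zero entries are dropped and
+# the remainder is renormalized:
+buckets = {(256, 256, 1): 0.5, (720, 1280, 1): 1.5, (720, 1280, 16): 0.0}
+buckets = {k: v for k, v in buckets.items() if v > 0}
+total = sum(buckets.values())
+print({k: v / total for k, v in buckets.items()})
+# {(256, 256, 1): 0.25, (720, 1280, 1): 0.75}
+```
+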
+Behavioral notes:
+
+- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping (see the example below).
+
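+For example, the following combination raises an error because three base items per request cannot fit under an image cap of 1 (excerpt of a full `vllm bench serve` command):
+
+```bash
+  --random-mm-base-items-per-request 3 \
+  --random-mm-limit-mm-per-prompt '{"image": 1, "video": 0}' \
+```
+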
+How sampling works:
+
+- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits.
+- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added.
+- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing (see the sketch after this list). This is an edge case; it can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large value, although values larger than the engine's `--limit-mm-per-prompt` will cause request errors.
+- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format): prompts remain plain text and the multimodal content is attached to the request.
+
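+For illustration, here is a minimal standalone sketch of the sampling loop described above (names and example values are hypothetical; the actual logic lives in `RandomMultiModalDataset.get_mm_item_iterator`):
+
+```python
+import math
+
+import numpy as np
+
+rng = np.random.default_rng(42)
+base, ratio = 2, 0.5
+limits = {"image": 3, "video": 0}
+buckets = {(256, 256, 1): 0.7, (720, 1280, 1): 0.3}  # (H, W, T) -> prob
+
+
+def modality(cfg):
+    return "image" if cfg[2] == 1 else "video"
+
+
+# 1) Sample the per-request item count k, clamped to the sum of the limits.
+low = max(0, math.floor(base * (1 - ratio)))
+high = min(sum(limits.values()), math.ceil(base * (1 + ratio)))
+k = int(rng.integers(low, high + 1))
+
+# 2) Draw buckets one at a time, dropping a modality once it hits its cap;
+#    the remaining probabilities are renormalized on every draw.
+chosen, counts = [], {m: 0 for m in limits}
+while len(chosen) < k and buckets:
+    keys = list(buckets)
+    probs = np.array([buckets[key] for key in keys])
+    cfg = keys[rng.choice(len(keys), p=probs / probs.sum())]
+    if counts[modality(cfg)] < limits[modality(cfg)]:
+        counts[modality(cfg)] += 1
+        chosen.append(cfg)
+    else:
+        buckets = {c: p for c, p in buckets.items()
+                   if modality(c) != modality(cfg)}
+print(k, chosen)
+```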
diff --git a/tests/benchmarks/test_random_dataset.py b/tests/benchmarks/test_random_dataset.py
new file mode 100644
index 0000000000000..26cae369cdd5d
--- /dev/null
+++ b/tests/benchmarks/test_random_dataset.py
@@ -0,0 +1,344 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import random
+from typing import Any, NamedTuple, Optional, cast
+
+import numpy as np
+import pytest
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+
+from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset,
+ SampleRequest)
+
+
+@pytest.fixture(scope="session")
+def hf_tokenizer() -> PreTrainedTokenizerBase:
+ # Use a small, commonly available tokenizer
+ return AutoTokenizer.from_pretrained("gpt2")
+
+
+class Params(NamedTuple):
+ num_requests: int
+ prefix_len: int
+ range_ratio: float
+ input_len: int
+ output_len: int
+
+
+@pytest.fixture(scope="session")
+def random_dataset_params() -> Params:
+ return Params(num_requests=16,
+ prefix_len=7,
+ range_ratio=0.3,
+ input_len=50,
+ output_len=20)
+
+
+def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
+ """Project a SampleRequest into a comparable tuple."""
+ return (req.prompt, req.prompt_len, req.expected_output_len)
+
+
+def _collect_samples(dataset: RandomDataset,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int = 16,
+ prefix_len: int = 7,
+ range_ratio: float = 0.3,
+ input_len: int = 50,
+ output_len: int = 20) -> list[tuple[str, int, int]]:
+ samples = dataset.sample(
+ tokenizer=tokenizer,
+ num_requests=num_requests,
+ prefix_len=prefix_len,
+ range_ratio=range_ratio,
+ input_len=input_len,
+ output_len=output_len,
+ )
+ return [_fingerprint_sample(s) for s in samples]
+
+
+@pytest.mark.benchmark
+def test_random_dataset_same_seed(
+ hf_tokenizer: PreTrainedTokenizerBase,
+ random_dataset_params: Params) -> None:
+ """Same seed should yield identical outputs, even if global RNGs change.
+
+ This guards against accidental reliance on Python's random or np.random
+ in RandomDataset after moving to numpy.default_rng.
+ """
+ p = random_dataset_params
+ common_seed = 123
+ dataset_a = RandomDataset(random_seed=common_seed)
+ dataset_b = RandomDataset(random_seed=common_seed)
+ a = _collect_samples(dataset_a,
+ hf_tokenizer,
+ num_requests=p.num_requests,
+ prefix_len=p.prefix_len,
+ range_ratio=p.range_ratio,
+ input_len=p.input_len,
+ output_len=p.output_len)
+
+ # Perturb global RNG state to ensure isolation
+ random.seed(999)
+ _ = [random.random() for _ in range(100)]
+ np.random.seed(888)
+ _ = [np.random.random() for _ in range(100)]
+
+ b = _collect_samples(dataset_b,
+ hf_tokenizer,
+ num_requests=p.num_requests,
+ prefix_len=p.prefix_len,
+ range_ratio=p.range_ratio,
+ input_len=p.input_len,
+ output_len=p.output_len)
+ assert a == b
+
+@pytest.mark.benchmark
+def test_random_dataset_different_seeds(
+ hf_tokenizer: PreTrainedTokenizerBase,
+ random_dataset_params: Params) -> None:
+ """Different seeds should change outputs with overwhelming likelihood."""
+ p = random_dataset_params
+ seed_a = 0
+ dataset_a = RandomDataset(random_seed=seed_a)
+ a = _collect_samples(dataset_a,
+ hf_tokenizer,
+ num_requests=p.num_requests,
+ prefix_len=p.prefix_len,
+ range_ratio=p.range_ratio,
+ input_len=p.input_len,
+ output_len=p.output_len)
+
+ seed_b = 999
+ dataset_b = RandomDataset(random_seed=seed_b)
+ # Perturb global RNG with same seed as dataset_a to ensure isolation
+ random.seed(seed_a)
+ np.random.seed(seed_a)
+ b = _collect_samples(dataset_b,
+ hf_tokenizer,
+ num_requests=p.num_requests,
+ prefix_len=p.prefix_len,
+ range_ratio=p.range_ratio,
+ input_len=p.input_len,
+ output_len=p.output_len)
+ assert a != b
+
+
+# -----------------------------
+# RandomMultiModalDataset tests
+# -----------------------------
+
+def _mm_fingerprint_sample(
+ req: SampleRequest,
+) -> tuple[str, int, int, int, list[str]]:
+ """Create a compact fingerprint for multimodal samples.
+
+ Includes:
+ - prompt string
+ - prompt_len
+ - expected_output_len
+ - count of multimodal items
+ - per-item type and URL prefix (e.g., 'data:image/jpeg;base64,')
+ """
+ items = req.multi_modal_data or []
+ item_prefixes: list[str] = []
+ for it in items:
+ if isinstance(it, dict) and it.get("type") == "image_url":
+ url = it.get("image_url", {}).get("url", "")
+ # Only keep a short identifying prefix to avoid huge strings
+ item_prefixes.append(f"image:{url[:22]}")
+ elif isinstance(it, dict) and it.get("type") == "video_url":
+ url = it.get("video_url", {}).get("url", "")
+ item_prefixes.append(f"video:{url[:22]}")
+ else:
+ item_prefixes.append("unknown:")
+ return (req.prompt, req.prompt_len, req.expected_output_len, len(items),
+ item_prefixes)
+
+
+def _collect_mm_samples(
+ dataset: RandomMultiModalDataset,
+ tokenizer: PreTrainedTokenizerBase,
+ *,
+ num_requests: int = 8,
+ prefix_len: int = 3,
+ range_ratio: float = 0.0,
+ input_len: int = 20,
+ output_len: int = 5,
+ base_items_per_request: int = 2,
+ num_mm_items_range_ratio: float = 0.0,
+ limit_mm_per_prompt: Optional[dict[str, int]] = None,
+ bucket_config: Optional[dict[tuple[int, int, int], float]] = None,
+ enable_multimodal_chat: bool = False,
+) -> list[SampleRequest]:
+ if limit_mm_per_prompt is None:
+ limit_mm_per_prompt = {"image": 5, "video": 0}
+ if bucket_config is None:
+ bucket_config = {(32, 32, 1): 0.5, (52, 64, 1): 0.5}
+ return dataset.sample(
+ tokenizer=tokenizer,
+ num_requests=num_requests,
+ prefix_len=prefix_len,
+ range_ratio=range_ratio,
+ input_len=input_len,
+ output_len=output_len,
+ base_items_per_request=base_items_per_request,
+ num_mm_items_range_ratio=num_mm_items_range_ratio,
+ limit_mm_per_prompt=limit_mm_per_prompt,
+ bucket_config=bucket_config,
+ enable_multimodal_chat=enable_multimodal_chat,
+ )
+
+
+@pytest.mark.benchmark
+def test_random_mm_same_seed(hf_tokenizer: PreTrainedTokenizerBase) -> None:
+ seed = 42
+ ds_a = RandomMultiModalDataset(random_seed=seed)
+ ds_b = RandomMultiModalDataset(random_seed=seed)
+ a = _collect_mm_samples(ds_a, hf_tokenizer)
+ b = _collect_mm_samples(ds_b, hf_tokenizer)
+ fa = [_mm_fingerprint_sample(s) for s in a]
+ fb = [_mm_fingerprint_sample(s) for s in b]
+ assert fa == fb
+
+
+@pytest.mark.benchmark
+def test_random_mm_different_seeds(
+ hf_tokenizer: PreTrainedTokenizerBase,
+) -> None:
+ ds_a = RandomMultiModalDataset(random_seed=0)
+ ds_b = RandomMultiModalDataset(random_seed=999)
+ a = _collect_mm_samples(ds_a, hf_tokenizer)
+ b = _collect_mm_samples(ds_b, hf_tokenizer)
+ fa = [_mm_fingerprint_sample(s) for s in a]
+ fb = [_mm_fingerprint_sample(s) for s in b]
+ assert fa != fb
+
+@pytest.mark.benchmark
+def test_random_mm_respects_limits(
+ hf_tokenizer: PreTrainedTokenizerBase,
+) -> None:
+ ds = RandomMultiModalDataset(random_seed=0)
+ # Requesting 3 items with a per-prompt limit of 1 should error per current
+ # design (dataset refuses to silently clamp below the requested baseline).
+ with pytest.raises(ValueError):
+ _collect_mm_samples(
+ ds,
+ hf_tokenizer,
+ num_requests=12,
+ base_items_per_request=3,
+ num_mm_items_range_ratio=0.0,
+ limit_mm_per_prompt={"image": 1, "video": 0},
+ bucket_config={(32, 32, 1): 1.0},
+ )
+
+
+@pytest.mark.benchmark
+def test_random_mm_zero_prob_entries_are_removed(
+ hf_tokenizer: PreTrainedTokenizerBase,
+) -> None:
+ ds = RandomMultiModalDataset(random_seed=0)
+ # Second bucket has zero probability and should be ignored after
+ # normalization
+ samples = _collect_mm_samples(
+ ds,
+ hf_tokenizer,
+ num_requests=6,
+ base_items_per_request=2,
+ num_mm_items_range_ratio=0.0,
+ limit_mm_per_prompt={"image": 10, "video": 0},
+ bucket_config={(32, 32, 1): 1.0, (52, 64, 1): 0.0},
+ )
+ for s in samples:
+ assert isinstance(s.multi_modal_data, list)
+ typed_mm = cast(list[dict[str, Any]], s.multi_modal_data)
+ for it in typed_mm:
+ assert it.get("type") == "image_url"
+
+
+@pytest.mark.benchmark
+def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
+ ds = RandomMultiModalDataset(random_seed=0)
+ samples = _collect_mm_samples(
+ ds,
+ hf_tokenizer,
+ num_requests=5,
+ base_items_per_request=0,
+ num_mm_items_range_ratio=0.0,
+ limit_mm_per_prompt={"image": 5, "video": 0},
+ bucket_config={(32, 32, 1): 1.0},
+ )
+ for s in samples:
+ assert s.multi_modal_data == []
+
+@pytest.mark.benchmark
+def test_random_mm_num_items_per_prompt(
+ hf_tokenizer: PreTrainedTokenizerBase) -> None:
+ ds = RandomMultiModalDataset(random_seed=0)
+ # Fixed number of images per prompt
+ # set num_mm_items_range_ratio to 0.0
+ # TODO: modify video values when video sampling is implemented
+ samples_fixed_items = _collect_mm_samples(
+ ds,
+ hf_tokenizer,
+ num_requests=5,
+ base_items_per_request=3,
+ num_mm_items_range_ratio=0.0,
+ limit_mm_per_prompt={"image": 3, "video": 0},
+ bucket_config={(32, 32, 1): 1.0},
+ )
+ # Must have 5 requests each with 3 mm items per prompt
+ assert len(samples_fixed_items) == 5
+ for s in samples_fixed_items:
+ mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
+ assert len(mm_data) == 3
+ for it in mm_data:
+ assert it.get("type") == "image_url"
+
+    # Vary number of mm items per prompt
+    # set num_mm_items_range_ratio to 0.5
+    samples_varying_items = _collect_mm_samples(
+        ds,
+        hf_tokenizer,
+        num_requests=5,
+        base_items_per_request=2,
+        num_mm_items_range_ratio=0.5,
+        limit_mm_per_prompt={"image": 4, "video": 0},
+        bucket_config={(32, 32, 1): 1.0},
+    )
+    # Must have 5 requests, each with at most 4 mm items per prompt
+    # but at least 1 mm item per prompt
+    assert len(samples_varying_items) == 5
+    for s in samples_varying_items:
+        mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
+        assert len(mm_data) <= 4
+        assert len(mm_data) >= 1
+        for it in mm_data:
+            assert it.get("type") == "image_url"
+
+
+@pytest.mark.benchmark
+def test_random_mm_bucket_config_not_mutated(
+    hf_tokenizer: PreTrainedTokenizerBase,
+) -> None:
+    ds = RandomMultiModalDataset(random_seed=0)
+    # This bucket config is not normalized to sum to 1
+    # and has more buckets than requested images
+    original = {(32, 32, 1): 0.2, (52, 64, 1): 6, (25, 64, 1): 3}
+    # Keep a snapshot to compare after sampling
+    snapshot = dict(original)
+
+    _ = _collect_mm_samples(
+        ds,
+        hf_tokenizer,
+        num_requests=4,
+        base_items_per_request=1,
+        num_mm_items_range_ratio=0.0,
+        limit_mm_per_prompt={"image": 1, "video": 0},
+        bucket_config=original,
+    )
+
+    # Ensure the original dict content is unchanged
+    assert original == snapshot
diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py
index 920d21bda3c5b..e586337367b1c 100644
--- a/vllm/benchmarks/datasets.py
+++ b/vllm/benchmarks/datasets.py
@@ -11,18 +11,21 @@ generation. Supported dataset types include:
- HuggingFace
- VisionArena
"""
+import ast
import base64
import io
import json
import logging
+import math
import random
from abc import ABC, abstractmethod
-from collections.abc import Mapping
+from collections.abc import Iterator, Mapping
+from contextlib import suppress
from copy import deepcopy
from dataclasses import dataclass
from functools import cache
from io import BytesIO
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, Union, cast
import numpy as np
from PIL import Image
@@ -114,7 +117,9 @@ class BenchmarkDataset(ABC):
def apply_multimodal_chat_transformation(
self,
prompt: str,
- mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
+ mm_content: Optional[
+ Union[MultiModalDataDict, dict, list[dict]]
+ ] = None) -> list[dict]:
"""
Transform a prompt and optional multimodal content into a chat format.
This method is used for chat models that expect a specific conversation
@@ -122,7 +127,15 @@ class BenchmarkDataset(ABC):
"""
content = [{"text": prompt, "type": "text"}]
if mm_content is not None:
- content.append(mm_content)
+ if isinstance(mm_content, list):
+ content.extend(cast(list[dict[str, Any]], mm_content))
+ elif isinstance(mm_content, dict):
+ content.append(mm_content)
+ else:
+ raise TypeError(
+ "Could not process multimodal content of type: " +
+ f"{type(mm_content)}"
+ )
return [{"role": "user", "content": content}]
def load_data(self) -> None:
@@ -362,90 +375,536 @@ def process_video(video: Any) -> Mapping[str, Any]:
class RandomDataset(BenchmarkDataset):
+ """
+ Synthetic text-only dataset for serving/throughput benchmarks.
+
+ Strategy:
+ - Sample input/output token lengths per request from integer-uniform ranges
+ around configured means (controlled by range_ratio).
+ - Prepend a fixed random prefix of length prefix_len.
+ - Generate the remaining tokens as a reproducible sequence:
+ (offset + index + arange(input_len)) % vocab_size.
+ - Decode then re-encode/truncate to ensure prompt token counts match.
+ - Uses numpy.default_rng seeded with random_seed for reproducible sampling.
+ """
# Default values copied from benchmark_serving.py for the random dataset.
DEFAULT_PREFIX_LEN = 0
DEFAULT_RANGE_RATIO = 0.0
DEFAULT_INPUT_LEN = 1024
DEFAULT_OUTPUT_LEN = 128
- def __init__(
- self,
- **kwargs,
- ) -> None:
+ def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
- random.seed(self.random_seed)
- np.random.seed(self.random_seed)
+ # Use numpy's default_rng for deterministic sampling
+ # Do not use random.seed() or np.random.seed() elsewhere in this class.
+ # This ensures that the RNG is isolated from global RNG state.
+ self._rng = np.random.default_rng(self.random_seed)
def sample(
self,
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
+ request_id_prefix: str = "",
prefix_len: int = DEFAULT_PREFIX_LEN,
range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
- request_id_prefix: str = "",
**kwargs,
) -> list[SampleRequest]:
- # Enforce range_ratio < 1
- assert range_ratio < 1.0, (
- "random_range_ratio must be < 1.0 to ensure a valid sampling range"
+
+ input_lens, output_lens, offsets = self.get_sampling_params(
+ num_requests, range_ratio, input_len, output_len, tokenizer
)
+ # Generate prefix once
+ prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
vocab_size = tokenizer.vocab_size
- num_special_tokens = tokenizer.num_special_tokens_to_add()
- real_input_len = input_len - num_special_tokens
-
- prefix_token_ids = (np.random.randint(
- 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
-
- # New sampling logic: [X * (1 - b), X * (1 + b)]
- input_low = int(real_input_len * (1 - range_ratio))
- input_high = int(real_input_len * (1 + range_ratio))
- output_low = int(output_len * (1 - range_ratio))
- output_high = int(output_len * (1 + range_ratio))
-
- # Add logging for debugging
- logger.info(
- "Sampling input_len from [%s, %s] and output_len from [%s, %s]",
- input_low, input_high, output_low, output_high)
-
- input_lens = np.random.randint(input_low,
- input_high + 1,
- size=num_requests)
- output_lens = np.random.randint(output_low,
- output_high + 1,
- size=num_requests)
- offsets = np.random.randint(0, vocab_size, size=num_requests)
requests = []
for i in range(num_requests):
- inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
- vocab_size).tolist()
- token_sequence = prefix_token_ids + inner_seq
- prompt = tokenizer.decode(token_sequence)
- # After decoding the prompt we have to encode and decode it again.
- # This is done because in some cases N consecutive tokens
- # give a string tokenized into != N number of tokens.
- # For example for GPT2Tokenizer:
- # [6880, 6881] -> ['Ġcalls', 'here'] ->
- # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
- # To avoid uncontrolled change of the prompt length,
- # the encoded sequence is truncated before being decode again.
- total_input_len = prefix_len + int(input_lens[i])
- re_encoded_sequence = tokenizer.encode(
- prompt, add_special_tokens=False)[:total_input_len]
- prompt = tokenizer.decode(re_encoded_sequence)
- total_input_len = len(re_encoded_sequence)
+ prompt, total_input_len = self.generate_token_sequence(
+ tokenizer=tokenizer,
+ prefix_token_ids=prefix_token_ids,
+ prefix_len=prefix_len,
+ vocab_size=vocab_size,
+ input_len=int(input_lens[i]),
+ offset=int(offsets[i]),
+ index=i,
+ )
requests.append(
SampleRequest(
prompt=prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
request_id=request_id_prefix + str(i),
- ))
+ )
+ )
return requests
+ def get_prefix(
+ self, tokenizer: PreTrainedTokenizerBase, prefix_len: int
+ ) -> list[int]:
+ """
+        Generate a random prefix of `prefix_len` token ids.
+ """
+ return (
+ self._rng.integers(
+ 0, tokenizer.vocab_size, size=prefix_len).tolist()
+ if prefix_len > 0
+ else []
+ )
+
+ def get_sampling_params(
+ self,
+ num_requests: int,
+ range_ratio: float,
+ input_len: int,
+ output_len: int,
+ tokenizer: PreTrainedTokenizerBase,
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+ """
+        Sample per-request input/output lengths and random token offsets.
+ """
+ # Enforce range_ratio < 1
+ if not (0.0 <= range_ratio < 1.0):
+ raise ValueError("range_ratio must be in [0, 1).")
+ num_special_tokens = int(tokenizer.num_special_tokens_to_add())
+ real_input_len = max(0, int(input_len) - num_special_tokens)
+ # Bounds use floor for low and ceil for high
+ input_low = math.floor(real_input_len * (1 - range_ratio))
+ input_high = math.ceil(real_input_len * (1 + range_ratio))
+ output_low = math.floor(output_len * (1 - range_ratio))
+ output_high = math.ceil(output_len * (1 + range_ratio))
+ # Ensure the lower bound for output length is at least 1 to
+ # prevent sampling 0 tokens.
+ output_low = max(output_low, 1)
+
+ if input_low > input_high:
+ raise ValueError(
+ "Invalid input sampling interval: "
+ f"low={input_low} > high={input_high}"
+ )
+ if output_low > output_high:
+ raise ValueError(
+ "Invalid output sampling interval: "
+ f"low={output_low} > high={output_high}"
+ )
+
+ logger.info(
+ "Sampling input_len from [%s, %s] and output_len from [%s, %s]",
+ input_low,
+ input_high,
+ output_low,
+ output_high,
+ )
+
+ input_lens = self._rng.integers(input_low, input_high + 1,
+ size=num_requests)
+ output_lens = self._rng.integers(output_low, output_high + 1,
+ size=num_requests)
+ offsets = self._rng.integers(0, tokenizer.vocab_size,
+ size=num_requests)
+ return input_lens, output_lens, offsets
+
+
+ def generate_token_sequence(
+ self,
+ *,
+ tokenizer: PreTrainedTokenizerBase,
+ prefix_token_ids: list[int],
+ prefix_len: int,
+ vocab_size: int,
+ input_len: int,
+ offset: int,
+ index: int,
+ ) -> tuple[str, int]:
+ """
+ Returns (prompt, total_input_len).
+
+ NOTE: After decoding the prompt we have to encode and decode it again.
+ This is done because in some cases N consecutive tokens
+ give a string tokenized into != N number of tokens.
+ For example for GPT2Tokenizer:
+ [6880, 6881] -> ['Ġcalls', 'here'] ->
+ [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
+ To avoid uncontrolled change of the prompt length,
+    the encoded sequence is truncated before being decoded again.
+ """
+ # Build the inner sequence by sampling sequentially from the vocab
+ inner_seq = ((offset + index + np.arange(input_len))
+ % vocab_size).tolist()
+ token_sequence = prefix_token_ids + inner_seq
+
+ # Decode, then re-encode and truncate to preserve token count invariants
+ prompt = tokenizer.decode(token_sequence)
+ total_input_len = prefix_len + int(input_len)
+
+ re_encoded_sequence = tokenizer.encode(
+ prompt, add_special_tokens=False)[:total_input_len]
+ prompt = tokenizer.decode(re_encoded_sequence)
+ total_input_len = len(re_encoded_sequence)
+
+ return prompt, total_input_len
+
+
+# -----------------------------------------------------------------------------
+# MultiModalDataset Implementation
+# -----------------------------------------------------------------------------
+
+class RandomMultiModalDataset(RandomDataset):
+ """
+ Synthetic multimodal dataset (text + images) that extends RandomDataset.
+
+ Status:
+ - Images: supported via synthetic RGB data.
+ - Video: not yet supported (TODO: implement video generation method).
+ - Audio: not yet supported.
+
+ Sampling overview:
+ 1) Number of items per request is sampled uniformly from the integer range
+ [floor(n·(1−r)), ceil(n·(1+r))], where n is the base count and r is
+ `num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0.
+ The maximum is further clamped to the sum of per-modality limits.
+ 2) Each item’s modality and shape is sampled from `bucket_config`, a dict
+ mapping (height, width, num_frames) → probability. We treat
+      `num_frames`=1 as image and `num_frames` > 1 as video.
+ Entries with zero probability are removed and the rest are renormalized
+ to sum to 1.
+ 3) Per-modality hard caps are enforced via `limit_mm_per_prompt`.
+ When a modality reaches its cap, all of its buckets are excluded and the
+ remaining probabilities are renormalized.
+
+ Example bucket configuration:
+ {(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1}
+ - Two image buckets (`num_frames`=1) and one video bucket
+ (`num_frames`=16).
+    Note: only image sampling is supported for now.
+ """
+
+ IS_MULTIMODAL = True
+ # NOTE: video sampling is WIP. Setting it to 0.
+ DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 0}
+
+ DEFAULT_BASE_ITEMS_PER_REQUEST = 1
+ DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0
+ DEFAULT_MM_ITEM_BUCKET_CONFIG = {
+ (256, 256, 1): 0.5,
+ (720, 1280, 1): 0.5,
+ (720, 1280, 16): 0.0,
+ }
+ DEFAULT_ENABLE_MULTIMODAL_CHAT = False
+
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+
+ def generate_synthetic_image(self, width: int, height: int) -> Image.Image:
+ """Generate synthetic PIL image with random RGB values.
+
+ NOTE: iid pixel sampling results in worst-case compression
+ (good for stressing I/O), but very unlike real photos.
+ We could consider a “low-freq” mode (e.g., noise blur)
+ to emulate network realism instead of max stress.
+ """
+ random_pixels = self._rng.integers(
+ 0,
+ 256,
+ (height, width, 3),
+ dtype=np.uint8,
+ )
+ return Image.fromarray(random_pixels)
+
+ def generate_synthetic_video(self, width: int,
+ height: int,
+ num_frames: int) -> Any:
+ """Generate synthetic video with random values.
+
+ TODO: Finish this method.
+ """
+ raise NotImplementedError("Video sampling is WIP.")
+
+ def map_config_to_modality(self, config: tuple[int, int, int]) -> str:
+ """Map the configuration to the modality."""
+ if config[-1] == 1:
+ return "image"
+ elif config[-1] > 1:
+ return "video"
+ else:
+ raise ValueError(f"Invalid multimodal item configuration: {config}")
+
+ def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int],
+ float]) -> dict[tuple[int, int, int], float]:
+ """
+ Remove zero probability entries
+ and normalize the bucket config to sum to 1.
+ """
+ # Raise error if value is negative
+ if any(v < 0 for v in bucket_config.values()):
+ raise ValueError("Bucket config values must be non-negative.")
+ # Remove zero probability entries
+ bucket_config = {k: v for k, v in bucket_config.items() if v > 0}
+ # if bucket config is empty, raise error
+ if not bucket_config:
+ raise ValueError("Got invalid bucket config. "
+ "Bucket config values must be non-zero.")
+ # Normalize the remaining bucket config to sum to 1
+ total = sum(bucket_config.values())
+ return {k: v / total for k, v in bucket_config.items()}
+
+
+ def generate_mm_item(self,
+ mm_item_config: tuple[int, int, int],
+ ) -> Mapping[str, Any]:
+ """
+ Create synthetic images and videos and
+ apply process_image/process_video respectively.
+ This follows the OpenAI API chat completions
+ https://github.com/openai/openai-python
+ """
+
+ if self.map_config_to_modality(mm_item_config) == "image":
+ return process_image(self.generate_synthetic_image(
+ mm_item_config[1],
+ mm_item_config[0]))
+ elif self.map_config_to_modality(mm_item_config) == "video":
+ return process_video(self.generate_synthetic_video(
+ mm_item_config[1],
+ mm_item_config[0],
+ mm_item_config[2]))
+ else:
+ raise ValueError(f"Invalid multimodal item configuration: "
+ f"{mm_item_config}")
+
+
+ def get_mm_item_sampling_params(
+ self,
+ base_items_per_request: int,
+ num_mm_items_range_ratio: float,
+ limit_mm_per_prompt: dict[str, int],
+ bucket_config: dict[tuple[int, int, int], float],
+ ) -> tuple[int, int, dict[str, int], dict[tuple[int, int, int], float]]:
+ """
+ Get the sampling parameters for the multimodal items.
+ """
+ # Enforce num_mm_items_range_ratio <= 1
+ if not (0.0 <= num_mm_items_range_ratio <= 1.0):
+ raise ValueError("num_mm_items_range_ratio must be in [0, 1].")
+
+ # Ensure modalities to sample are in limit_mm_per_prompt
+ for k, v in bucket_config.items():
+ # get modality from bucket config
+ modality = self.map_config_to_modality(k)
+ if modality not in limit_mm_per_prompt:
+ raise ValueError(f"Modality {modality} is not in "
+ f"limit_mm_per_prompt: "
+ f"{limit_mm_per_prompt.keys()}")
+
+ # Remove zero probability entries
+ # and normalize bucket config to sum to 1
+ bucket_config = self.normalize_bucket_config(bucket_config)
+ logger.info(
+ "Normalized bucket config: %s", bucket_config,
+ )
+ # Only consider limit per prompt for modalities in bucket config
+ allowed_modalities = {self.map_config_to_modality(cfg)
+ for cfg in bucket_config}
+ limit_mm_per_prompt = {
+ k: v for k, v in limit_mm_per_prompt.items()
+ if k in allowed_modalities}
+ if not limit_mm_per_prompt:
+ raise ValueError("No valid limits for modalities present in "
+ "bucket_config.")
+
+ logger.info(
+ "Updated mm-limit-per-prompt: %s", limit_mm_per_prompt,
+ )
+
+ # Get max and min num mm items and ensure
+ # it is at most the sum of limit_mm_per_prompt for all modalities
+ max_num_mm_items = min(
+ sum(limit_mm_per_prompt.values()),
+ math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio))
+ )
+ # Ensure min num mm items is at least 0
+ min_num_mm_items = max(
+ 0,
+ math.floor(base_items_per_request * (1 - num_mm_items_range_ratio))
+ )
+ # Raise error if min num mm items is greater than max num mm items
+ if min_num_mm_items > max_num_mm_items:
+ raise ValueError(f"Min num mm items is greater than max mm items: "
+ f"{min_num_mm_items} > {max_num_mm_items}")
+
+ logger.info(
+ "Sampling number of multimodal items from [%s, %s]",
+ min_num_mm_items, max_num_mm_items,
+ )
+
+ return (
+ min_num_mm_items,
+ max_num_mm_items,
+ limit_mm_per_prompt,
+ bucket_config,
+ )
+
+ def get_mm_item_iterator(
+ self,
+ min_num_mm_items: int,
+ max_num_mm_items: int,
+ bucket_config: dict[tuple[int, int, int], float],
+ limit_mm_per_prompt: dict[str, int],
+    ) -> Iterator[tuple[int, int, int]]:
+        """
+        Iterate over the multimodal item configurations for one request,
+        yielding between min_num_mm_items and max_num_mm_items items.
+
+        Buckets are sampled according to their probabilities until the
+        requested number of items has been yielded or every modality has
+        reached its per-prompt limit.
+
+        Note:
+        - This function operates on a per-request shallow copy of
+          `bucket_config` (tuple -> float). The original dict passed to
+          `sample` is not mutated; a dedicated test fails if this changes.
+        """
+ # Get the number of multimodal items to sample
+ request_num_mm_items = int(
+ self._rng.integers(min_num_mm_items, max_num_mm_items + 1)
+ )
+ # If request_num_mm_items is 0, yield an empty iterator
+ if request_num_mm_items == 0:
+ return
+ # Initialize modality counters
+ modality_counter = {self.map_config_to_modality(k): 0
+ for k in bucket_config}
+ # Copy the bucket config to avoid modifying the original
+ bucket_config_copy = bucket_config.copy()
+ # Loop over the number of multimodal items to sample
+ while sum(modality_counter.values()) < request_num_mm_items:
+ # Sample a multimodal item config
+ mm_item_config = self._rng.choice(list(bucket_config_copy.keys()),
+ p=list(bucket_config_copy.values()))
+ modality = self.map_config_to_modality(mm_item_config)
+ # Check that modality count is less than limit per prompt
+ if modality_counter[modality] < limit_mm_per_prompt[modality]:
+ modality_counter[modality] += 1
+                yield mm_item_config
+ else:
+ # If the counter is greater than the limit per prompt
+ # set all multimodal items of this modality to 0
+ for k, v in bucket_config_copy.items():
+ if self.map_config_to_modality(k) == modality:
+ bucket_config_copy[k] = 0
+ # If all configs are 0, break the loop
+ # This should not happen as request_num_mm_items is at most
+ # the sum of limit_mm_per_prompt for all modalities
+ if all(v == 0 for v in bucket_config_copy.values()):
+ logger.warning("Exhausted all multimodal items "
+ "of modality %s",
+ modality)
+ break
+ # Renormalize the bucket config
+ bucket_config_copy = self.normalize_bucket_config(
+ bucket_config_copy)
+
+
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ request_id_prefix: str = "",
+ prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN,
+ range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
+ input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
+ output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN,
+ limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT,
+ base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST,
+ num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
+ bucket_config: dict[tuple[int, int, int], float] =
+ DEFAULT_MM_ITEM_BUCKET_CONFIG,
+ enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT,
+ **kwargs,
+ ) -> list[SampleRequest]:
+
+ # NOTE: Video sampling is WIP. Raise error if video is in bucket config
+ # and probability is non-zero.
+ if any(self.map_config_to_modality(cfg) == "video" and p > 0
+ for cfg, p in bucket_config.items()):
+ raise NotImplementedError("Video sampling not implemented; "
+ "set its probability to 0.")
+
+ # Get the sampling parameters for the dataset
+ input_lens, output_lens, offsets = self.get_sampling_params(
+ num_requests, range_ratio, input_len, output_len, tokenizer
+ )
+
+ (
+ min_num_mm_items,
+ max_num_mm_items,
+ limit_mm_per_prompt,
+ bucket_config,
+ ) = self.get_mm_item_sampling_params(
+ base_items_per_request,
+ num_mm_items_range_ratio,
+ limit_mm_per_prompt,
+ bucket_config,
+ )
+
+ # Generate prefix once
+ prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
+ vocab_size = tokenizer.vocab_size
+ # Add synthetic multimodal items to each request
+ mm_requests = []
+ for i in range(num_requests):
+ prompt, total_input_len = self.generate_token_sequence(
+ tokenizer=tokenizer,
+ prefix_token_ids=prefix_token_ids,
+ prefix_len=prefix_len,
+ vocab_size=vocab_size,
+ input_len=int(input_lens[i]),
+ offset=int(offsets[i]),
+ index=i,
+ )
+ # Get multimodal item iterator for a given request
+ mm_item_iterator = self.get_mm_item_iterator(
+ min_num_mm_items,
+ max_num_mm_items,
+ bucket_config,
+ limit_mm_per_prompt,
+ )
+
+ mm_content = cast(list[dict[str, Any]], [
+ self.generate_mm_item(mm_item_config)
+ for mm_item_config in mm_item_iterator
+ ])
+
+ if enable_multimodal_chat:
+ # NOTE: For now this option is only provided for completeness
+ # given that the serve.py benchmark currently does not use it.
+ mm_chat_prompt: Any = prompt
+ mm_chat_prompt = self.apply_multimodal_chat_transformation(
+ prompt, mm_content)
+ sample_request = SampleRequest(
+ prompt=mm_chat_prompt,
+ prompt_len=total_input_len,
+ expected_output_len=int(output_lens[i]),
+ multi_modal_data=None,
+ request_id=request_id_prefix + str(i),
+ )
+ else:
+ sample_request = SampleRequest(
+ prompt=prompt,
+ prompt_len=total_input_len,
+ expected_output_len=int(output_lens[i]),
+ multi_modal_data=mm_content,
+ request_id=request_id_prefix + str(i),
+ )
+ mm_requests.append(sample_request)
+ return mm_requests
# -----------------------------------------------------------------------------
# ShareGPT Dataset Implementation
@@ -545,8 +1004,8 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
type=str,
default="random",
choices=[
- "sharegpt", "burstgpt", "sonnet", "random", "hf", "custom",
- "prefix_repetition"
+ "sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
+ "custom", "prefix_repetition"
],
help="Name of the dataset to benchmark on.",
)
@@ -647,6 +1106,98 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
"input_len * (1 + range_ratio)]."),
)
+ # random multimodal dataset options
+ random_mm_group = parser.add_argument_group(
+ "random multimodal dataset options extended from random dataset")
+ random_mm_group.add_argument(
+ "--random-mm-base-items-per-request",
+ type=int,
+ default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST,
+ help=(
+ "Base number of multimodal items per request for random-mm. "
+ "Actual per-request count is sampled around this base using "
+ "--random-mm-num-mm-items-range-ratio."
+ ),
+ )
+ random_mm_group.add_argument(
+ "--random-mm-num-mm-items-range-ratio",
+ type=float,
+ default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
+ help=(
+ "Range ratio r in [0, 1] for sampling items per request. "
+ "We sample uniformly from the closed integer range "
+ "[floor(n*(1-r)), ceil(n*(1+r))] "
+ "where n is the base items per request. "
+ "r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped "
+ "to the sum of per-modality limits from "
+ "--random-mm-limit-mm-per-prompt. "
+ "An error is raised if the computed min exceeds the max."
+ ),
+ )
+ random_mm_group.add_argument(
+ "--random-mm-limit-mm-per-prompt",
+ type=json.loads,
+ default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT,
+ help=(
+ "Per-modality hard caps for items attached per request, e.g. "
+ "'{\"image\": 3, \"video\": 0}'. The sampled per-request item "
+ "count is clamped to the sum of these limits. When a modality "
+ "reaches its cap, its buckets are excluded and probabilities are "
+ "renormalized."
+ "OBS.: Only image sampling is supported for now."
+ ),
+ )
+
+ def _parse_mm_bucket_config(v: object) -> dict[tuple[int, int, int], float]:
+ # If already a dict (e.g., programmatic call), normalize keys
+ def normalize(d: dict) -> dict[tuple[int, int, int], float]:
+ out: dict[tuple[int, int, int], float] = {}
+ for k, val in d.items():
+ key = k
+ if isinstance(key, str):
+ with suppress(Exception):
+ key = ast.literal_eval(key)
+ if not (isinstance(key, tuple) and len(key) == 3
+ and all(isinstance(x, int) for x in key)):
+ raise ValueError(
+ f"Invalid bucket key {k!r}. Expected tuple (H, W, T)."
+ )
+ out[(int(key[0]), int(key[1]), int(key[2]))] = float(val)
+ return out
+
+ if isinstance(v, dict):
+ return normalize(v)
+ if isinstance(v, str):
+ # Python literal (supports tuple keys)
+ parsed = ast.literal_eval(v)
+ if not isinstance(parsed, dict):
+ raise ValueError("Bucket config must parse to a dict.")
+ return normalize(parsed)
+ raise ValueError("Unsupported value for --random-mm-bucket-config.")
+
+ random_mm_group.add_argument(
+ "--random-mm-bucket-config",
+ type=_parse_mm_bucket_config,
+ default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG,
+ help=(
+ "The bucket config is a dictionary mapping a multimodal item"
+ "sampling configuration to a probability."
+ "Currently allows for 2 modalities: images and videos. "
+ "An bucket key is a tuple of (height, width, num_frames)"
+ "The value is the probability of sampling that specific item. "
+ "Example: "
+ "--random-mm-bucket-config "
+ "{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10} "
+ "First item: images with resolution 256x256 w.p. 0.5"
+ "Second item: images with resolution 720x1280 w.p. 0.4 "
+ "Third item: videos with resolution 720x1280 and 16 frames w.p. 0.1"
+ "OBS.: If the probabilities do not sum to 1, they are normalized."
+ "OBS bis.: Only image sampling is supported for now."
+ ),
+ )
+
hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset",
type=str,
@@ -821,6 +1372,22 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
range_ratio=args.random_range_ratio,
request_id_prefix=args.request_id_prefix,
),
+ "random-mm":
+ lambda: RandomMultiModalDataset(
+ random_seed=args.seed, dataset_path=args.dataset_path
+ ).sample(
+ tokenizer=tokenizer,
+ num_requests=args.num_prompts,
+ prefix_len=args.random_prefix_len,
+ range_ratio=args.random_range_ratio,
+ input_len=args.random_input_len,
+ output_len=args.random_output_len,
+ base_items_per_request=args.random_mm_base_items_per_request,
+ limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt,
+ num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio,
+ bucket_config=args.random_mm_bucket_config,
+ request_id_prefix=args.request_id_prefix,
+ ),
"prefix_repetition":
lambda: PrefixRepetitionRandomDataset(
random_seed=args.seed, dataset_path=args.dataset_path
@@ -836,6 +1403,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
}
try:
+ # Enforce endpoint compatibility for multimodal datasets.
+ if args.dataset_name == "random-mm" and args.endpoint_type not in [
+ "openai-chat"]:
+ raise ValueError(
+ "Multi-modal content (images) is only supported on "
+ "'openai-chat' backend."
+ )
input_requests = dataset_mapping[args.dataset_name]()
except KeyError as err:
raise ValueError(f"Unknown dataset: {args.dataset_name}") from err