mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 00:27:01 +08:00
Feature/benchmark/random mm data/images (#23119)
Signed-off-by: breno.skuk <breno.skuk@hcompany.ai>
parent
2da02dd0d8
commit
0cb7b065c3
@ -59,6 +59,12 @@ become available.
|
|||||||
<td style="text-align: center;">✅</td>
|
<td style="text-align: center;">✅</td>
|
||||||
<td><code>synthetic</code></td>
|
<td><code>synthetic</code></td>
|
||||||
</tr>
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><strong>RandomMultiModal (Image/Video)</strong></td>
|
||||||
|
<td style="text-align: center;">🟡</td>
|
||||||
|
<td style="text-align: center;">🚧</td>
|
||||||
|
<td><code>synthetic</code> </td>
|
||||||
|
</tr>
|
||||||
<tr>
|
<tr>
|
||||||
<td><strong>Prefix Repetition</strong></td>
|
<td><strong>Prefix Repetition</strong></td>
|
||||||
<td style="text-align: center;">✅</td>
|
<td style="text-align: center;">✅</td>
|
||||||
@ -722,4 +728,75 @@ python benchmarks/benchmark_serving.py \
|
|||||||
--endpoint /v1/chat/completion
|
--endpoint /v1/chat/completion
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Synthetic Random Images (random-mm)
|
||||||
|
|
||||||
|
Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
|
||||||
|
- Works only with the online benchmark, via the OpenAI chat backend (`--backend openai-chat`) and the `/v1/chat/completions` endpoint.
|
||||||
|
- Video sampling is not yet implemented.
|
||||||
|
|
||||||
|
Start the server (example):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
|
||||||
|
--dtype bfloat16 \
|
||||||
|
--max-model-len 16384 \
|
||||||
|
--limit-mm-per-prompt '{"image": 3, "video": 0}' \
|
||||||
|
--mm-processor-kwargs max_pixels=1003520
|
||||||
|
```
|
||||||
|
|
||||||
|
Benchmark. Using the `--ignore-eos` flag is recommended to simulate real responses. Set the output length with `--random-output-len`.
|
||||||
|
|
||||||
|
Ex.1: Fixed number of items and a single image resolution, enforcing generation of approximately 40 tokens:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
vllm bench serve \
|
||||||
|
--backend openai-chat \
|
||||||
|
--model Qwen/Qwen2.5-VL-3B-Instruct \
|
||||||
|
--endpoint /v1/chat/completions \
|
||||||
|
--dataset-name random-mm \
|
||||||
|
--num-prompts 100 \
|
||||||
|
--max-concurrency 10 \
|
||||||
|
--random-prefix-len 25 \
|
||||||
|
--random-input-len 300 \
|
||||||
|
--random-output-len 40 \
|
||||||
|
--random-range-ratio 0.2 \
|
||||||
|
--random-mm-base-items-per-request 2 \
|
||||||
|
--random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \
|
||||||
|
--random-mm-bucket-config '{(224, 224, 1): 1.0}' \
|
||||||
|
--request-rate inf \
|
||||||
|
--ignore-eos \
|
||||||
|
--seed 42
|
||||||
|
```
|
||||||
|
|
||||||
|
The number of items per request can be varied with `--random-mm-num-mm-items-range-ratio`, and multiple image buckets can be passed to mix resolutions:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
--random-mm-base-items-per-request 2 \
|
||||||
|
--random-mm-num-mm-items-range-ratio 0.5 \
|
||||||
|
--random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \
|
||||||
|
--random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \
|
||||||
|
```
|
||||||
|
|
||||||
|
Flags specific to `random-mm`:
|
||||||
|
|
||||||
|
- `--random-mm-base-items-per-request`: base number of multimodal items per request.
|
||||||
|
- `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))], where n is the base item count and r is this ratio. Set r=0 to keep it fixed; r=1 allows 0 items.
|
||||||
|
- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'.
|
||||||
|
- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 for videos (video sampling not yet supported).
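The two numeric rules in the flag list above (the item-count range and the bucket renormalization) can be illustrated with a small sketch. This is not the vLLM implementation; `item_count_range` and `normalize_buckets` are hypothetical helper names introduced here only to show the arithmetic described by the flags.

```python
import math


def item_count_range(base_items: int, range_ratio: float) -> tuple[int, int]:
    """Hypothetical helper: closed integer range [floor(n*(1-r)), ceil(n*(1+r))]."""
    if not 0.0 <= range_ratio <= 1.0:
        raise ValueError("range_ratio must be in [0, 1]")
    low = math.floor(base_items * (1 - range_ratio))
    high = math.ceil(base_items * (1 + range_ratio))
    return low, high


def normalize_buckets(
    bucket_config: dict[tuple[int, int, int], float],
) -> dict[tuple[int, int, int], float]:
    """Hypothetical helper: drop zero-probability buckets, rescale the rest.

    Returns a new dict so the caller's configuration is left untouched.
    """
    kept = {bucket: float(p) for bucket, p in bucket_config.items() if p > 0}
    total = sum(kept.values())
    if total <= 0:
        raise ValueError("at least one bucket needs a positive probability")
    return {bucket: p / total for bucket, p in kept.items()}


if __name__ == "__main__":
    # n=2, r=0.5 gives the closed range [1, 3]
    print(item_count_range(2, 0.5))
    # The zero-probability video bucket is dropped before rescaling
    print(normalize_buckets({(256, 256, 1): 0.7, (720, 1280, 1): 0.3,
                             (720, 1280, 16): 0.0}))
```

With `n=2` and `r=0.5` the count is drawn from 1 to 3 inclusive, which is why the second example raises the image limit above the base count.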
|
||||||
|
|
||||||
|
Behavioral notes:
|
||||||
|
|
||||||
|
- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping.
|
||||||
|
|
||||||
|
How sampling works:
|
||||||
|
|
||||||
|
- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits.
|
||||||
|
- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added.
|
||||||
|
- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing.
|
||||||
|
This should be seen as an edge case; it can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number. Note that doing so might result in errors if requests then exceed the engine's `--limit-mm-per-prompt` setting.
|
||||||
|
- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`.
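The sampling steps above can be sketched as follows. This is a simplified illustration under stated assumptions, not the code from this commit: `sample_items` and `modality_of` are hypothetical names, the standard-library `random.Random` stands in for the dataset's seeded generator, and the error raised when the base count cannot be satisfied is left out.

```python
import random

Bucket = tuple[int, int, int]  # (H, W, T)


def modality_of(bucket: Bucket) -> str:
    """T == 1 is treated as an image, T > 1 as a video."""
    return "image" if bucket[2] == 1 else "video"


def sample_items(
    k: int,
    bucket_config: dict[Bucket, float],
    limits: dict[str, int],
    rng: random.Random,
) -> list[Bucket]:
    """Hypothetical sketch: draw up to k buckets while honoring per-modality caps."""
    # Clamp k to the total capacity across modalities
    k = min(k, sum(limits.values()))
    counts = {modality: 0 for modality in limits}
    # Zero-probability buckets never participate
    active = {b: p for b, p in bucket_config.items() if p > 0}
    chosen: list[Bucket] = []
    for _ in range(k):
        # Exclude every bucket whose modality has reached its cap
        active = {
            b: p for b, p in active.items()
            if counts.get(modality_of(b), 0) < limits.get(modality_of(b), 0)
        }
        if not active:
            break
        buckets = list(active)
        # random.choices uses relative weights, which amounts to
        # renormalizing the remaining probabilities
        bucket = rng.choices(buckets, weights=[active[b] for b in buckets], k=1)[0]
        counts[modality_of(bucket)] += 1
        chosen.append(bucket)
    return chosen


if __name__ == "__main__":
    rng = random.Random(42)
    print(sample_items(3, {(256, 256, 1): 0.7, (720, 1280, 1): 0.3},
                       {"image": 4, "video": 0}, rng))
```

Passing the generator in explicitly keeps the draw reproducible for a given seed and independent of any global random state.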
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|||||||
344
tests/benchmarks/test_random_dataset.py
Normal file
@ -0,0 +1,344 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import random
|
||||||
|
from typing import Any, NamedTuple, Optional, cast
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset,
|
||||||
|
SampleRequest)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def hf_tokenizer() -> PreTrainedTokenizerBase:
|
||||||
|
# Use a small, commonly available tokenizer
|
||||||
|
return AutoTokenizer.from_pretrained("gpt2")
|
||||||
|
|
||||||
|
|
||||||
|
class Params(NamedTuple):
|
||||||
|
num_requests: int
|
||||||
|
prefix_len: int
|
||||||
|
range_ratio: float
|
||||||
|
input_len: int
|
||||||
|
output_len: int
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def random_dataset_params() -> Params:
|
||||||
|
return Params(num_requests=16,
|
||||||
|
prefix_len=7,
|
||||||
|
range_ratio=0.3,
|
||||||
|
input_len=50,
|
||||||
|
output_len=20)
|
||||||
|
|
||||||
|
|
||||||
|
def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]:
|
||||||
|
"""Project a SampleRequest into a comparable tuple."""
|
||||||
|
return (req.prompt, req.prompt_len, req.expected_output_len)
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_samples(dataset: RandomDataset,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
num_requests: int = 16,
|
||||||
|
prefix_len: int = 7,
|
||||||
|
range_ratio: float = 0.3,
|
||||||
|
input_len: int = 50,
|
||||||
|
output_len: int = 20) -> list[tuple[str, int, int]]:
|
||||||
|
samples = dataset.sample(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
num_requests=num_requests,
|
||||||
|
prefix_len=prefix_len,
|
||||||
|
range_ratio=range_ratio,
|
||||||
|
input_len=input_len,
|
||||||
|
output_len=output_len,
|
||||||
|
)
|
||||||
|
return [_fingerprint_sample(s) for s in samples]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.benchmark
|
||||||
|
def test_random_dataset_same_seed(
|
||||||
|
hf_tokenizer: PreTrainedTokenizerBase,
|
||||||
|
random_dataset_params: Params) -> None:
|
||||||
|
"""Same seed should yield identical outputs, even if global RNGs change.
|
||||||
|
|
||||||
|
This guards against accidental reliance on Python's random or np.random
|
||||||
|
in RandomDataset after moving to numpy.default_rng.
|
||||||
|
"""
|
||||||
|
p = random_dataset_params
|
||||||
|
common_seed = 123
|
||||||
|
dataset_a = RandomDataset(random_seed=common_seed)
|
||||||
|
dataset_b = RandomDataset(random_seed=common_seed)
|
||||||
|
a = _collect_samples(dataset_a,
|
||||||
|
hf_tokenizer,
|
||||||
|
num_requests=p.num_requests,
|
||||||
|
prefix_len=p.prefix_len,
|
||||||
|
range_ratio=p.range_ratio,
|
||||||
|
input_len=p.input_len,
|
||||||
|
output_len=p.output_len)
|
||||||
|
|
||||||
|
# Perturb global RNG state to ensure isolation
|
||||||
|
random.seed(999)
|
||||||
|
_ = [random.random() for _ in range(100)]
|
||||||
|
np.random.seed(888)
|
||||||
|
_ = [np.random.random() for _ in range(100)]
|
||||||
|
|
||||||
|
b = _collect_samples(dataset_b,
|
||||||
|
hf_tokenizer,
|
||||||
|
num_requests=p.num_requests,
|
||||||
|
prefix_len=p.prefix_len,
|
||||||
|
range_ratio=p.range_ratio,
|
||||||
|
input_len=p.input_len,
|
||||||
|
output_len=p.output_len)
|
||||||
|
assert a == b
|
||||||
|
|
||||||
|
@pytest.mark.benchmark
|
||||||
|
def test_random_dataset_different_seeds(
|
||||||
|
hf_tokenizer: PreTrainedTokenizerBase,
|
||||||
|
random_dataset_params: Params) -> None:
|
||||||
|
"""Different seeds should change outputs with overwhelming likelihood."""
|
||||||
|
p = random_dataset_params
|
||||||
|
seed_a = 0
|
||||||
|
dataset_a = RandomDataset(random_seed=seed_a)
|
||||||
|
a = _collect_samples(dataset_a,
|
||||||
|
hf_tokenizer,
|
||||||
|
num_requests=p.num_requests,
|
||||||
|
prefix_len=p.prefix_len,
|
||||||
|
range_ratio=p.range_ratio,
|
||||||
|
input_len=p.input_len,
|
||||||
|
output_len=p.output_len)
|
||||||
|
|
||||||
|
seed_b = 999
|
||||||
|
dataset_b = RandomDataset(random_seed=seed_b)
|
||||||
|
# Perturb global RNG with same seed as dataset_a to ensure isolation
|
||||||
|
random.seed(seed_a)
|
||||||
|
np.random.seed(seed_a)
|
||||||
|
b = _collect_samples(dataset_b,
|
||||||
|
hf_tokenizer,
|
||||||
|
num_requests=p.num_requests,
|
||||||
|
prefix_len=p.prefix_len,
|
||||||
|
range_ratio=p.range_ratio,
|
||||||
|
input_len=p.input_len,
|
||||||
|
output_len=p.output_len)
|
||||||
|
assert a != b
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# RandomMultiModalDataset tests
|
||||||
|
# -----------------------------
|
||||||
|
|
||||||
|
def _mm_fingerprint_sample(
|
||||||
|
req: SampleRequest,
|
||||||
|
) -> tuple[str, int, int, int, list[str]]:
|
||||||
|
"""Create a compact fingerprint for multimodal samples.
|
||||||
|
|
||||||
|
Includes:
|
||||||
|
- prompt string
|
||||||
|
- prompt_len
|
||||||
|
- expected_output_len
|
||||||
|
- count of multimodal items
|
||||||
|
- per-item type and URL prefix (e.g., 'data:image/jpeg;base64,')
|
||||||
|
"""
|
||||||
|
items = req.multi_modal_data or []
|
||||||
|
item_prefixes: list[str] = []
|
||||||
|
for it in items:
|
||||||
|
if isinstance(it, dict) and it.get("type") == "image_url":
|
||||||
|
url = it.get("image_url", {}).get("url", "")
|
||||||
|
# Only keep a short identifying prefix to avoid huge strings
|
||||||
|
item_prefixes.append(f"image:{url[:22]}")
|
||||||
|
elif isinstance(it, dict) and it.get("type") == "video_url":
|
||||||
|
url = it.get("video_url", {}).get("url", "")
|
||||||
|
item_prefixes.append(f"video:{url[:22]}")
|
||||||
|
else:
|
||||||
|
item_prefixes.append("unknown:")
|
||||||
|
return (req.prompt, req.prompt_len, req.expected_output_len, len(items),
|
||||||
|
item_prefixes)
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_mm_samples(
|
||||||
|
dataset: RandomMultiModalDataset,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
*,
|
||||||
|
num_requests: int = 8,
|
||||||
|
prefix_len: int = 3,
|
||||||
|
range_ratio: float = 0.0,
|
||||||
|
input_len: int = 20,
|
||||||
|
output_len: int = 5,
|
||||||
|
base_items_per_request: int = 2,
|
||||||
|
num_mm_items_range_ratio: float = 0.0,
|
||||||
|
limit_mm_per_prompt: Optional[dict[str, int]] = None,
|
||||||
|
bucket_config: Optional[dict[tuple[int, int, int], float]] = None,
|
||||||
|
enable_multimodal_chat: bool = False,
|
||||||
|
) -> list[SampleRequest]:
|
||||||
|
if limit_mm_per_prompt is None:
|
||||||
|
limit_mm_per_prompt = {"image": 5, "video": 0}
|
||||||
|
if bucket_config is None:
|
||||||
|
bucket_config = {(32, 32, 1): 0.5, (52, 64, 1): 0.5}
|
||||||
|
return dataset.sample(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
num_requests=num_requests,
|
||||||
|
prefix_len=prefix_len,
|
||||||
|
range_ratio=range_ratio,
|
||||||
|
input_len=input_len,
|
||||||
|
output_len=output_len,
|
||||||
|
base_items_per_request=base_items_per_request,
|
||||||
|
num_mm_items_range_ratio=num_mm_items_range_ratio,
|
||||||
|
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||||
|
bucket_config=bucket_config,
|
||||||
|
enable_multimodal_chat=enable_multimodal_chat,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.benchmark
|
||||||
|
def test_random_mm_same_seed(hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
||||||
|
seed = 42
|
||||||
|
ds_a = RandomMultiModalDataset(random_seed=seed)
|
||||||
|
ds_b = RandomMultiModalDataset(random_seed=seed)
|
||||||
|
a = _collect_mm_samples(ds_a, hf_tokenizer)
|
||||||
|
b = _collect_mm_samples(ds_b, hf_tokenizer)
|
||||||
|
fa = [_mm_fingerprint_sample(s) for s in a]
|
||||||
|
fb = [_mm_fingerprint_sample(s) for s in b]
|
||||||
|
assert fa == fb
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.benchmark
|
||||||
|
def test_random_mm_different_seeds(
|
||||||
|
hf_tokenizer: PreTrainedTokenizerBase,
|
||||||
|
) -> None:
|
||||||
|
ds_a = RandomMultiModalDataset(random_seed=0)
|
||||||
|
ds_b = RandomMultiModalDataset(random_seed=999)
|
||||||
|
a = _collect_mm_samples(ds_a, hf_tokenizer)
|
||||||
|
b = _collect_mm_samples(ds_b, hf_tokenizer)
|
||||||
|
fa = [_mm_fingerprint_sample(s) for s in a]
|
||||||
|
fb = [_mm_fingerprint_sample(s) for s in b]
|
||||||
|
assert fa != fb
|
||||||
|
|
||||||
|
@pytest.mark.benchmark
|
||||||
|
def test_random_mm_respects_limits(
|
||||||
|
hf_tokenizer: PreTrainedTokenizerBase,
|
||||||
|
) -> None:
|
||||||
|
ds = RandomMultiModalDataset(random_seed=0)
|
||||||
|
# Requesting 3 items with a per-prompt limit of 1 should error per current
|
||||||
|
# design (dataset refuses to silently clamp below the requested baseline).
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_collect_mm_samples(
|
||||||
|
ds,
|
||||||
|
hf_tokenizer,
|
||||||
|
num_requests=12,
|
||||||
|
base_items_per_request=3,
|
||||||
|
num_mm_items_range_ratio=0.0,
|
||||||
|
limit_mm_per_prompt={"image": 1, "video": 0},
|
||||||
|
bucket_config={(32, 32, 1): 1.0},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.benchmark
|
||||||
|
def test_random_mm_zero_prob_entries_are_removed(
|
||||||
|
hf_tokenizer: PreTrainedTokenizerBase,
|
||||||
|
) -> None:
|
||||||
|
ds = RandomMultiModalDataset(random_seed=0)
|
||||||
|
# Second bucket has zero probability and should be ignored after
|
||||||
|
# normalization
|
||||||
|
samples = _collect_mm_samples(
|
||||||
|
ds,
|
||||||
|
hf_tokenizer,
|
||||||
|
num_requests=6,
|
||||||
|
base_items_per_request=2,
|
||||||
|
num_mm_items_range_ratio=0.0,
|
||||||
|
limit_mm_per_prompt={"image": 10, "video": 0},
|
||||||
|
bucket_config={(32, 32, 1): 1.0, (52, 64, 1): 0.0},
|
||||||
|
)
|
||||||
|
for s in samples:
|
||||||
|
assert isinstance(s.multi_modal_data, list)
|
||||||
|
typed_mm = cast(list[dict[str, Any]], s.multi_modal_data)
|
||||||
|
for it in typed_mm:
|
||||||
|
assert it.get("type") == "image_url"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.benchmark
|
||||||
|
def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
||||||
|
ds = RandomMultiModalDataset(random_seed=0)
|
||||||
|
samples = _collect_mm_samples(
|
||||||
|
ds,
|
||||||
|
hf_tokenizer,
|
||||||
|
num_requests=5,
|
||||||
|
base_items_per_request=0,
|
||||||
|
num_mm_items_range_ratio=0.0,
|
||||||
|
limit_mm_per_prompt={"image": 5, "video": 0},
|
||||||
|
bucket_config={(32, 32, 1): 1.0},
|
||||||
|
)
|
||||||
|
for s in samples:
|
||||||
|
assert s.multi_modal_data == []
|
||||||
|
|
||||||
|
@pytest.mark.benchmark
|
||||||
|
def test_random_mm_num_items_per_prompt(
|
||||||
|
hf_tokenizer: PreTrainedTokenizerBase) -> None:
|
||||||
|
ds = RandomMultiModalDataset(random_seed=0)
|
||||||
|
# Fixed number of images per prompt
|
||||||
|
# set num_mm_items_range_ratio to 0.0
|
||||||
|
# TODO: modify video values when video sampling is implemented
|
||||||
|
samples_fixed_items = _collect_mm_samples(
|
||||||
|
ds,
|
||||||
|
hf_tokenizer,
|
||||||
|
num_requests=5,
|
||||||
|
base_items_per_request=3,
|
||||||
|
num_mm_items_range_ratio=0.0,
|
||||||
|
limit_mm_per_prompt={"image": 3, "video": 0},
|
||||||
|
bucket_config={(32, 32, 1): 1.0},
|
||||||
|
)
|
||||||
|
# Must have 5 requests each with 3 mm items per prompt
|
||||||
|
assert len(samples_fixed_items) == 5
|
||||||
|
for s in samples_fixed_items:
|
||||||
|
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
|
||||||
|
assert len(mm_data) == 3
|
||||||
|
for it in mm_data:
|
||||||
|
assert it.get("type") == "image_url"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.benchmark
|
||||||
|
def test_random_mm_bucket_config_not_mutated(
|
||||||
|
hf_tokenizer: PreTrainedTokenizerBase,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
ds = RandomMultiModalDataset(random_seed=0)
|
||||||
|
# This bucket config is not normalized to sum to 1
|
||||||
|
# and has more buckets than requested images
|
||||||
|
original = {(32, 32, 1): 0.2, (52, 64, 1): 6, (25, 64, 1): 3}
|
||||||
|
# Keep a snapshot to compare after sampling
|
||||||
|
snapshot = dict(original)
|
||||||
|
|
||||||
|
_ = _collect_mm_samples(
|
||||||
|
ds,
|
||||||
|
hf_tokenizer,
|
||||||
|
num_requests=4,
|
||||||
|
base_items_per_request=1,
|
||||||
|
num_mm_items_range_ratio=0.0,
|
||||||
|
limit_mm_per_prompt={"image": 1, "video": 0},
|
||||||
|
bucket_config=original,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure the original dict content is unchanged
|
||||||
|
assert original == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
# Vary number of mm items per prompt
|
||||||
|
# set num_mm_items_range_ratio to 0.5
|
||||||
|
samples_varying_items = _collect_mm_samples(
|
||||||
|
ds,
|
||||||
|
hf_tokenizer,
|
||||||
|
num_requests=5,
|
||||||
|
base_items_per_request=2,
|
||||||
|
num_mm_items_range_ratio=0.5,
|
||||||
|
limit_mm_per_prompt={"image": 4, "video": 0},
|
||||||
|
bucket_config={(32, 32, 1): 1.0},
|
||||||
|
)
|
||||||
|
# Must have 5 requests each with at most 4 mm items per prompt
|
||||||
|
# but at least 1 mm item per prompt
|
||||||
|
assert len(samples_varying_items) == 5
|
||||||
|
for s in samples_varying_items:
|
||||||
|
mm_data = cast(list[dict[str, Any]], s.multi_modal_data)
|
||||||
|
assert len(mm_data) <= 4
|
||||||
|
assert len(mm_data) >= 1
|
||||||
|
for it in mm_data:
|
||||||
|
assert it.get("type") == "image_url"
|
||||||
@ -11,18 +11,21 @@ generation. Supported dataset types include:
|
|||||||
- HuggingFace
|
- HuggingFace
|
||||||
- VisionArena
|
- VisionArena
|
||||||
"""
|
"""
|
||||||
|
import ast
|
||||||
import base64
|
import base64
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import random
|
import random
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Mapping
|
from collections.abc import Iterator, Mapping
|
||||||
|
from contextlib import suppress
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import cache
|
from functools import cache
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -114,7 +117,9 @@ class BenchmarkDataset(ABC):
|
|||||||
def apply_multimodal_chat_transformation(
|
def apply_multimodal_chat_transformation(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
mm_content: Optional[MultiModalDataDict] = None) -> list[dict]:
|
mm_content: Optional[
|
||||||
|
Union[MultiModalDataDict, dict, list[dict]]
|
||||||
|
] = None) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
Transform a prompt and optional multimodal content into a chat format.
|
Transform a prompt and optional multimodal content into a chat format.
|
||||||
This method is used for chat models that expect a specific conversation
|
This method is used for chat models that expect a specific conversation
|
||||||
@ -122,7 +127,15 @@ class BenchmarkDataset(ABC):
|
|||||||
"""
|
"""
|
||||||
content = [{"text": prompt, "type": "text"}]
|
content = [{"text": prompt, "type": "text"}]
|
||||||
if mm_content is not None:
|
if mm_content is not None:
|
||||||
content.append(mm_content)
|
if isinstance(mm_content, list):
|
||||||
|
content.extend(cast(list[dict[str, Any]], mm_content))
|
||||||
|
elif isinstance(mm_content, dict):
|
||||||
|
content.append(mm_content)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"Could not process multimodal content of type: " +
|
||||||
|
f"{type(mm_content)}"
|
||||||
|
)
|
||||||
return [{"role": "user", "content": content}]
|
return [{"role": "user", "content": content}]
|
||||||
|
|
||||||
def load_data(self) -> None:
|
def load_data(self) -> None:
|
||||||
@ -362,90 +375,536 @@ def process_video(video: Any) -> Mapping[str, Any]:
|
|||||||
|
|
||||||
|
|
||||||
class RandomDataset(BenchmarkDataset):
|
class RandomDataset(BenchmarkDataset):
|
||||||
|
"""
|
||||||
|
Synthetic text-only dataset for serving/throughput benchmarks.
|
||||||
|
|
||||||
|
Strategy:
|
||||||
|
- Sample input/output token lengths per request from integer-uniform ranges
|
||||||
|
around configured means (controlled by range_ratio).
|
||||||
|
- Prepend a fixed random prefix of length prefix_len.
|
||||||
|
- Generate the remaining tokens as a reproducible sequence:
|
||||||
|
(offset + index + arange(input_len)) % vocab_size.
|
||||||
|
- Decode then re-encode/truncate to ensure prompt token counts match.
|
||||||
|
- Uses numpy.default_rng seeded with random_seed for reproducible sampling.
|
||||||
|
"""
|
||||||
# Default values copied from benchmark_serving.py for the random dataset.
|
# Default values copied from benchmark_serving.py for the random dataset.
|
||||||
DEFAULT_PREFIX_LEN = 0
|
DEFAULT_PREFIX_LEN = 0
|
||||||
DEFAULT_RANGE_RATIO = 0.0
|
DEFAULT_RANGE_RATIO = 0.0
|
||||||
DEFAULT_INPUT_LEN = 1024
|
DEFAULT_INPUT_LEN = 1024
|
||||||
DEFAULT_OUTPUT_LEN = 128
|
DEFAULT_OUTPUT_LEN = 128
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, **kwargs) -> None:
|
||||||
self,
|
|
||||||
**kwargs,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
random.seed(self.random_seed)
|
# Use numpy's default_rng for deterministic sampling
|
||||||
np.random.seed(self.random_seed)
|
# Do not use random.seed() or np.random.seed() elsewhere in this class.
|
||||||
|
# This ensures that the RNG is isolated from global RNG state.
|
||||||
|
self._rng = np.random.default_rng(self.random_seed)
|
||||||
|
|
||||||
def sample(
|
def sample(
|
||||||
self,
|
self,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
num_requests: int,
|
num_requests: int,
|
||||||
|
request_id_prefix: str = "",
|
||||||
prefix_len: int = DEFAULT_PREFIX_LEN,
|
prefix_len: int = DEFAULT_PREFIX_LEN,
|
||||||
range_ratio: float = DEFAULT_RANGE_RATIO,
|
range_ratio: float = DEFAULT_RANGE_RATIO,
|
||||||
input_len: int = DEFAULT_INPUT_LEN,
|
input_len: int = DEFAULT_INPUT_LEN,
|
||||||
output_len: int = DEFAULT_OUTPUT_LEN,
|
output_len: int = DEFAULT_OUTPUT_LEN,
|
||||||
request_id_prefix: str = "",
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> list[SampleRequest]:
|
) -> list[SampleRequest]:
|
||||||
# Enforce range_ratio < 1
|
|
||||||
assert range_ratio < 1.0, (
|
input_lens, output_lens, offsets = self.get_sampling_params(
|
||||||
"random_range_ratio must be < 1.0 to ensure a valid sampling range"
|
num_requests, range_ratio, input_len, output_len, tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Generate prefix once
|
||||||
|
prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
|
||||||
vocab_size = tokenizer.vocab_size
|
vocab_size = tokenizer.vocab_size
|
||||||
num_special_tokens = tokenizer.num_special_tokens_to_add()
|
|
||||||
real_input_len = input_len - num_special_tokens
|
|
||||||
|
|
||||||
prefix_token_ids = (np.random.randint(
|
|
||||||
0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])
|
|
||||||
|
|
||||||
# New sampling logic: [X * (1 - b), X * (1 + b)]
|
|
||||||
input_low = int(real_input_len * (1 - range_ratio))
|
|
||||||
input_high = int(real_input_len * (1 + range_ratio))
|
|
||||||
output_low = int(output_len * (1 - range_ratio))
|
|
||||||
output_high = int(output_len * (1 + range_ratio))
|
|
||||||
|
|
||||||
# Add logging for debugging
|
|
||||||
logger.info(
|
|
||||||
"Sampling input_len from [%s, %s] and output_len from [%s, %s]",
|
|
||||||
input_low, input_high, output_low, output_high)
|
|
||||||
|
|
||||||
input_lens = np.random.randint(input_low,
|
|
||||||
input_high + 1,
|
|
||||||
size=num_requests)
|
|
||||||
output_lens = np.random.randint(output_low,
|
|
||||||
output_high + 1,
|
|
||||||
size=num_requests)
|
|
||||||
offsets = np.random.randint(0, vocab_size, size=num_requests)
|
|
||||||
|
|
||||||
requests = []
|
requests = []
|
||||||
for i in range(num_requests):
|
for i in range(num_requests):
|
||||||
inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) %
|
prompt, total_input_len = self.generate_token_sequence(
|
||||||
vocab_size).tolist()
|
tokenizer=tokenizer,
|
||||||
token_sequence = prefix_token_ids + inner_seq
|
prefix_token_ids=prefix_token_ids,
|
||||||
prompt = tokenizer.decode(token_sequence)
|
prefix_len=prefix_len,
|
||||||
# After decoding the prompt we have to encode and decode it again.
|
vocab_size=vocab_size,
|
||||||
# This is done because in some cases N consecutive tokens
|
input_len=int(input_lens[i]),
|
||||||
# give a string tokenized into != N number of tokens.
|
offset=int(offsets[i]),
|
||||||
# For example for GPT2Tokenizer:
|
index=i,
|
||||||
# [6880, 6881] -> ['Ġcalls', 'here'] ->
|
)
|
||||||
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
|
|
||||||
# To avoid uncontrolled change of the prompt length,
|
|
||||||
# the encoded sequence is truncated before being decoded again.
|
|
||||||
total_input_len = prefix_len + int(input_lens[i])
|
|
||||||
re_encoded_sequence = tokenizer.encode(
|
|
||||||
prompt, add_special_tokens=False)[:total_input_len]
|
|
||||||
prompt = tokenizer.decode(re_encoded_sequence)
|
|
||||||
total_input_len = len(re_encoded_sequence)
|
|
||||||
requests.append(
|
requests.append(
|
||||||
SampleRequest(
|
SampleRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
prompt_len=total_input_len,
|
prompt_len=total_input_len,
|
||||||
expected_output_len=int(output_lens[i]),
|
expected_output_len=int(output_lens[i]),
|
||||||
request_id=request_id_prefix + str(i),
|
request_id=request_id_prefix + str(i),
|
||||||
))
|
)
|
||||||
|
)
|
||||||
return requests
|
return requests
|
||||||
|
|
||||||
|
def get_prefix(
|
||||||
|
self, tokenizer: PreTrainedTokenizerBase, prefix_len: int
|
||||||
|
) -> list[int]:
|
||||||
|
"""
|
||||||
|
Get the prefix for the dataset.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
self._rng.integers(
|
||||||
|
0, tokenizer.vocab_size, size=prefix_len).tolist()
|
||||||
|
if prefix_len > 0
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_sampling_params(
|
||||||
|
self,
|
||||||
|
num_requests: int,
|
||||||
|
range_ratio: float,
|
||||||
|
input_len: int,
|
||||||
|
output_len: int,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Get the sampling parameters for the dataset.
|
||||||
|
"""
|
||||||
|
# Enforce 0 <= range_ratio < 1
|
||||||
|
if not (0.0 <= range_ratio < 1.0):
|
||||||
|
raise ValueError("range_ratio must be in [0, 1).")
|
||||||
|
num_special_tokens = int(tokenizer.num_special_tokens_to_add())
|
||||||
|
real_input_len = max(0, int(input_len) - num_special_tokens)
|
||||||
|
# Bounds use floor for low and ceil for high
|
||||||
|
input_low = math.floor(real_input_len * (1 - range_ratio))
|
||||||
|
input_high = math.ceil(real_input_len * (1 + range_ratio))
|
||||||
|
output_low = math.floor(output_len * (1 - range_ratio))
|
||||||
|
output_high = math.ceil(output_len * (1 + range_ratio))
|
||||||
|
# Ensure the lower bound for output length is at least 1 to
|
||||||
|
# prevent sampling 0 tokens.
|
||||||
|
output_low = max(output_low, 1)
|
||||||
|
|
||||||
|
if input_low > input_high:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid input sampling interval: "
|
||||||
|
f"low={input_low} > high={input_high}"
|
||||||
|
)
|
||||||
|
if output_low > output_high:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid output sampling interval: "
|
||||||
|
f"low={output_low} > high={output_high}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Sampling input_len from [%s, %s] and output_len from [%s, %s]",
|
||||||
|
input_low,
|
||||||
|
input_high,
|
||||||
|
output_low,
|
||||||
|
output_high,
|
||||||
|
)
|
||||||
|
|
||||||
|
input_lens = self._rng.integers(input_low, input_high + 1,
|
||||||
|
size=num_requests)
|
||||||
|
output_lens = self._rng.integers(output_low, output_high + 1,
|
||||||
|
size=num_requests)
|
||||||
|
offsets = self._rng.integers(0, tokenizer.vocab_size,
|
||||||
|
size=num_requests)
|
||||||
|
return input_lens, output_lens, offsets
|
||||||
|
|
||||||
|
|
||||||
|
def generate_token_sequence(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
prefix_token_ids: list[int],
|
||||||
|
prefix_len: int,
|
||||||
|
vocab_size: int,
|
||||||
|
input_len: int,
|
||||||
|
offset: int,
|
||||||
|
index: int,
|
||||||
|
) -> tuple[str, int]:
|
||||||
|
"""
|
||||||
|
Returns (prompt, total_input_len).
|
||||||
|
|
||||||
|
NOTE: After decoding the prompt we have to encode and decode it again.
|
||||||
|
This is done because in some cases N consecutive tokens
|
||||||
|
give a string tokenized into != N number of tokens.
|
||||||
|
For example for GPT2Tokenizer:
|
||||||
|
[6880, 6881] -> ['Ġcalls', 'here'] ->
|
||||||
|
[1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
|
||||||
|
To avoid uncontrolled change of the prompt length,
|
||||||
|
the encoded sequence is truncated before being decoded again.
|
||||||
|
"""
|
||||||
|
# Build the inner sequence by sampling sequentially from the vocab
|
||||||
|
inner_seq = ((offset + index + np.arange(input_len))
|
||||||
|
% vocab_size).tolist()
|
||||||
|
token_sequence = prefix_token_ids + inner_seq
|
||||||
|
|
||||||
|
# Decode, then re-encode and truncate to preserve token count invariants
|
||||||
|
prompt = tokenizer.decode(token_sequence)
|
||||||
|
total_input_len = prefix_len + int(input_len)
|
||||||
|
|
||||||
|
re_encoded_sequence = tokenizer.encode(
|
||||||
|
prompt, add_special_tokens=False)[:total_input_len]
|
||||||
|
prompt = tokenizer.decode(re_encoded_sequence)
|
||||||
|
total_input_len = len(re_encoded_sequence)
|
||||||
|
|
||||||
|
return prompt, total_input_len
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# MultiModalDataset Implementation
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class RandomMultiModalDataset(RandomDataset):
|
||||||
|
"""
|
||||||
|
Synthetic multimodal dataset (text + images) that extends RandomDataset.
|
||||||
|
|
||||||
|
Status:
|
||||||
|
- Images: supported via synthetic RGB data.
|
||||||
|
- Video: not yet supported (TODO: implement video generation method).
|
||||||
|
- Audio: not yet supported.
|
||||||
|
|
||||||
|
Sampling overview:
|
||||||
|
1) Number of items per request is sampled uniformly from the integer range
|
||||||
|
[floor(n·(1−r)), ceil(n·(1+r))], where n is the base count and r is
|
||||||
|
`num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0.
|
||||||
|
The maximum is further clamped to the sum of per-modality limits.
|
||||||
|
2) Each item’s modality and shape is sampled from `bucket_config`, a dict
|
||||||
|
mapping (height, width, num_frames) → probability. We treat
|
||||||
|
`num_frames`=1 as image and `num_frames` > 1 as video.
|
||||||
|
Entries with zero probability are removed and the rest are renormalized
|
||||||
|
to sum to 1.
|
||||||
|
3) Per-modality hard caps are enforced via `limit_mm_per_prompt`.
|
||||||
|
When a modality reaches its cap, all of its buckets are excluded and the
|
||||||
|
remaining probabilities are renormalized.
|
||||||
|
|
||||||
|
Example bucket configuration:
|
||||||
|
{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1}
|
||||||
|
- Two image buckets (`num_frames`=1) and one video bucket
|
||||||
|
(`num_frames`=16).
|
||||||
|
OBS.: Only image sampling is supported for now.
|
||||||
|
"""
|
||||||
|
|
||||||
|
IS_MULTIMODAL = True
|
||||||
|
# NOTE: video sampling is WIP. Setting it to 0.
|
||||||
|
DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 0}
|
||||||
|
|
||||||
|
DEFAULT_BASE_ITEMS_PER_REQUEST = 1
|
||||||
|
DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0
|
||||||
|
DEFAULT_MM_ITEM_BUCKET_CONFIG = {
|
||||||
|
(256, 256, 1): 0.5,
|
||||||
|
(720, 1280, 1): 0.5,
|
||||||
|
(720, 1280, 16): 0.0,
|
||||||
|
}
|
||||||
|
DEFAULT_ENABLE_MULTIMODAL_CHAT = False
|
||||||
|
|
||||||
|
def __init__(self, **kwargs) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_synthetic_image(self, width: int, height: int) -> Image.Image:
|
||||||
|
"""Generate synthetic PIL image with random RGB values.
|
||||||
|
|
||||||
|
NOTE: iid pixel sampling results in worst-case compression
|
||||||
|
(good for stressing I/O), but very unlike real photos.
|
||||||
|
We could consider a “low-freq” mode (e.g., noise blur)
|
||||||
|
to emulate network realism instead of max stress.
|
||||||
|
"""
|
||||||
|
random_pixels = self._rng.integers(
|
||||||
|
0,
|
||||||
|
256,
|
||||||
|
(height, width, 3),
|
||||||
|
dtype=np.uint8,
|
||||||
|
)
|
||||||
|
return Image.fromarray(random_pixels)
|
||||||
|
|
||||||
|
def generate_synthetic_video(self, width: int,
|
||||||
|
height: int,
|
||||||
|
num_frames: int) -> Any:
|
||||||
|
"""Generate synthetic video with random values.
|
||||||
|
|
||||||
|
TODO: Finish this method.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Video sampling is WIP.")
|
||||||
|
|
||||||
|
def map_config_to_modality(self, config: tuple[int, int, int]) -> str:
|
||||||
|
"""Map the configuration to the modality."""
|
||||||
|
if config[-1] == 1:
|
||||||
|
return "image"
|
||||||
|
elif config[-1] > 1:
|
||||||
|
return "video"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid multimodal item configuration: {config}")
|
||||||
|
|
||||||
|
def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int],
|
||||||
|
float]) -> dict[tuple[int, int, int], float]:
|
||||||
|
"""
|
||||||
|
Remove zero probability entries
|
||||||
|
and normalize the bucket config to sum to 1.
|
||||||
|
"""
|
||||||
|
# Raise error if value is negative
|
||||||
|
if any(v < 0 for v in bucket_config.values()):
|
||||||
|
raise ValueError("Bucket config values must be non-negative.")
|
||||||
|
# Remove zero probability entries
|
||||||
|
bucket_config = {k: v for k, v in bucket_config.items() if v > 0}
|
||||||
|
# if bucket config is empty, raise error
|
||||||
|
if not bucket_config:
|
||||||
|
raise ValueError("Got invalid bucket config. "
|
||||||
|
"Bucket config values must be non-zero.")
|
||||||
|
# Normalize the remaining bucket config to sum to 1
|
||||||
|
total = sum(bucket_config.values())
|
||||||
|
return {k: v / total for k, v in bucket_config.items()}
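As a usage illustration of the normalization described above, the following standalone sketch mirrors the same steps outside the class. The function name `normalize` and the sample weights are assumptions made for this example only.

```python
def normalize(bucket_config):
    """Drop zero-weight buckets and rescale the rest to sum to 1."""
    if any(p < 0 for p in bucket_config.values()):
        raise ValueError("Bucket config values must be non-negative.")
    positive = {k: p for k, p in bucket_config.items() if p > 0}
    if not positive:
        raise ValueError("At least one bucket must have a positive weight.")
    total = sum(positive.values())
    return {k: p / total for k, p in positive.items()}


# Unnormalized weights: the video bucket (16 frames) has weight 0 and is dropped.
weights = {(256, 256, 1): 1.0, (720, 1280, 1): 3.0, (720, 1280, 16): 0.0}
print(normalize(weights))  # {(256, 256, 1): 0.25, (720, 1280, 1): 0.75}
```

Because zero-weight entries are removed rather than kept at probability 0, a disabled modality never appears among the candidate buckets.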
|
||||||
|
|
||||||
|
|
||||||
|
def generate_mm_item(self,
|
||||||
|
mm_item_config: tuple[int, int, int],
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
"""
|
||||||
|
Create synthetic images and videos and
|
||||||
|
apply process_image/process_video respectively.
|
||||||
|
The result follows the OpenAI Chat Completions API format:
|
||||||
|
https://github.com/openai/openai-python
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.map_config_to_modality(mm_item_config) == "image":
|
||||||
|
return process_image(self.generate_synthetic_image(
|
||||||
|
mm_item_config[1],
|
||||||
|
mm_item_config[0]))
|
||||||
|
elif self.map_config_to_modality(mm_item_config) == "video":
|
||||||
|
return process_video(self.generate_synthetic_video(
|
||||||
|
mm_item_config[1],
|
||||||
|
mm_item_config[0],
|
||||||
|
mm_item_config[2]))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid multimodal item configuration: "
|
||||||
|
f"{mm_item_config}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_mm_item_sampling_params(
|
||||||
|
self,
|
||||||
|
base_items_per_request: int,
|
||||||
|
num_mm_items_range_ratio: float,
|
||||||
|
limit_mm_per_prompt: dict[str, int],
|
||||||
|
bucket_config: dict[tuple[int, int, int], float],
|
||||||
|
) -> tuple[int, int, dict[str, int], dict[tuple[int, int, int], float]]:
|
||||||
|
"""
|
||||||
|
Get the sampling parameters for the multimodal items.
|
||||||
|
"""
|
||||||
|
# Enforce 0 <= num_mm_items_range_ratio <= 1
|
||||||
|
if not (0.0 <= num_mm_items_range_ratio <= 1.0):
|
||||||
|
raise ValueError("num_mm_items_range_ratio must be in [0, 1].")
|
||||||
|
|
||||||
|
# Ensure modalities to sample are in limit_mm_per_prompt
|
||||||
|
for k, v in bucket_config.items():
|
||||||
|
# get modality from bucket config
|
||||||
|
modality = self.map_config_to_modality(k)
|
||||||
|
if modality not in limit_mm_per_prompt:
|
||||||
|
raise ValueError(f"Modality {modality} is not in "
|
||||||
|
f"limit_mm_per_prompt: "
|
||||||
|
f"{limit_mm_per_prompt.keys()}")
|
||||||
|
|
||||||
|
# Remove zero probability entries
|
||||||
|
# and normalize bucket config to sum to 1
|
||||||
|
bucket_config = self.normalize_bucket_config(bucket_config)
|
||||||
|
logger.info(
|
||||||
|
"Normalized bucket config: %s", bucket_config,
|
||||||
|
)
|
||||||
|
# Only consider limit per prompt for modalities in bucket config
|
||||||
|
allowed_modalities = {self.map_config_to_modality(cfg)
|
||||||
|
for cfg in bucket_config}
|
||||||
|
limit_mm_per_prompt = {
|
||||||
|
k: v for k, v in limit_mm_per_prompt.items()
|
||||||
|
if k in allowed_modalities}
|
||||||
|
if not limit_mm_per_prompt:
|
||||||
|
raise ValueError("No valid limits for modalities present in "
|
||||||
|
"bucket_config.")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Updated mm-limit-per-prompt: %s", limit_mm_per_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get max and min num mm items and ensure
|
||||||
|
# it is at most the sum of limit_mm_per_prompt for all modalities
|
||||||
|
max_num_mm_items = min(
|
||||||
|
sum(limit_mm_per_prompt.values()),
|
||||||
|
math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio))
|
||||||
|
)
|
||||||
|
# Ensure min num mm items is at least 0
|
||||||
|
min_num_mm_items = max(
|
||||||
|
0,
|
||||||
|
math.floor(base_items_per_request * (1 - num_mm_items_range_ratio))
|
||||||
|
)
|
||||||
|
# Raise error if min num mm items is greater than max num mm items
|
||||||
|
if min_num_mm_items > max_num_mm_items:
|
||||||
|
raise ValueError(f"Min num mm items is greater than max mm items: "
|
||||||
|
f"{min_num_mm_items} > {max_num_mm_items}")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Sampling number of multimodal items from [%s, %s]",
|
||||||
|
min_num_mm_items, max_num_mm_items,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
min_num_mm_items,
|
||||||
|
max_num_mm_items,
|
||||||
|
limit_mm_per_prompt,
|
||||||
|
bucket_config,
|
||||||
|
)
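The min/max computation above can be checked in isolation. This is a minimal sketch with a hypothetical helper name, `item_count_range`; it reproduces only the arithmetic, not the logging or the bucket handling.

```python
import math


def item_count_range(base, ratio, limits):
    """Return (min_items, max_items) for a base count n and range ratio r."""
    if not 0.0 <= ratio <= 1.0:
        raise ValueError("ratio must be in [0, 1].")
    low = max(0, math.floor(base * (1 - ratio)))
    # The upper bound is clamped to the sum of the per-modality caps.
    high = min(sum(limits.values()), math.ceil(base * (1 + ratio)))
    if low > high:
        raise ValueError(f"min {low} exceeds max {high}")
    return low, high


print(item_count_range(2, 0.5, {"image": 3}))  # (1, 3)
print(item_count_range(1, 1.0, {"image": 1}))  # (0, 1): r=1 allows zero items
```

Note that a fixed count (`ratio=0`) larger than the summed limits, for example a base of 4 with `{"image": 3}`, raises an error rather than being silently reduced.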
|
||||||
|
|
||||||
|
def get_mm_item_iterator(
|
||||||
|
self,
|
||||||
|
min_num_mm_items: int,
|
||||||
|
max_num_mm_items: int,
|
||||||
|
bucket_config: dict[tuple[int, int, int], float],
|
||||||
|
limit_mm_per_prompt: dict[str, int],
|
||||||
|
) -> Iterator[tuple[int, int, int]]:
|
||||||
|
"""
|
||||||
|
Iterator over the multimodal items for each request
|
||||||
|
whose size is between min_num_mm_items and max_num_mm_items.
|
||||||
|
|
||||||
|
Loop over the bucket config and sample a multimodal item.
|
||||||
|
Loop until the number of multimodal items sampled is equal to
|
||||||
|
request_num_mm_items or limit of multimodal items per prompt
|
||||||
|
for all modalities is reached.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- This function operates on a per-request shallow copy of
|
||||||
|
`bucket_config` (tuple->float). The original dict passed to
|
||||||
|
`sample` is not mutated; an existing test fails if this
|
||||||
|
ever changes.
|
||||||
|
"""
|
||||||
|
# Get the number of multimodal items to sample
|
||||||
|
request_num_mm_items = int(
|
||||||
|
self._rng.integers(min_num_mm_items, max_num_mm_items + 1)
|
||||||
|
)
|
||||||
|
# If request_num_mm_items is 0, yield an empty iterator
|
||||||
|
if request_num_mm_items == 0:
|
||||||
|
return
|
||||||
|
# Initialize modality counters
|
||||||
|
modality_counter = {self.map_config_to_modality(k): 0
|
||||||
|
for k in bucket_config}
|
||||||
|
# Copy the bucket config to avoid modifying the original
|
||||||
|
bucket_config_copy = bucket_config.copy()
|
||||||
|
# Loop over the number of multimodal items to sample
|
||||||
|
while sum(modality_counter.values()) < request_num_mm_items:
|
||||||
|
# Sample a multimodal item config
|
||||||
|
mm_item_config = self._rng.choice(list(bucket_config_copy.keys()),
|
||||||
|
p=list(bucket_config_copy.values()))
|
||||||
|
modality = self.map_config_to_modality(mm_item_config)
|
||||||
|
# Check that modality count is less than limit per prompt
|
||||||
|
if modality_counter[modality] < limit_mm_per_prompt[modality]:
|
||||||
|
modality_counter[modality] += 1
|
||||||
|
yield (
|
||||||
|
mm_item_config
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# If this modality has already reached its per-prompt limit,
|
||||||
|
# set the probability of all of its buckets to 0
|
||||||
|
for k, v in bucket_config_copy.items():
|
||||||
|
if self.map_config_to_modality(k) == modality:
|
||||||
|
bucket_config_copy[k] = 0
|
||||||
|
# If all configs are 0, break the loop
|
||||||
|
# This should not happen as request_num_mm_items is at most
|
||||||
|
# the sum of limit_mm_per_prompt for all modalities
|
||||||
|
if all(v == 0 for v in bucket_config_copy.values()):
|
||||||
|
logger.warning("Exhausted all multimodal items "
|
||||||
|
"of modality %s",
|
||||||
|
modality)
|
||||||
|
break
|
||||||
|
# Renormalize the bucket config
|
||||||
|
bucket_config_copy = self.normalize_bucket_config(
|
||||||
|
bucket_config_copy)
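The capped sampling loop above can be sketched with the standard library alone. This is an illustrative stand-in, not the class's implementation: it uses `random.Random.choices` (which accepts unnormalized weights, so no explicit renormalization is needed) instead of a NumPy generator, and `modality_of` and `sample_items` are hypothetical names. It only picks bucket keys and generates no image or video data.

```python
import random


def modality_of(bucket):
    """Buckets are (height, width, num_frames); one frame means image."""
    return "image" if bucket[2] == 1 else "video"


def sample_items(buckets, limits, num_items, rng):
    """Pick num_items bucket keys while respecting per-modality caps."""
    weights = dict(buckets)  # per-request copy; the caller's dict is untouched
    counts = {m: 0 for m in limits}
    picked = []
    while len(picked) < num_items and any(w > 0 for w in weights.values()):
        keys = list(weights)
        bucket = rng.choices(keys, weights=[weights[k] for k in keys], k=1)[0]
        modality = modality_of(bucket)
        if counts[modality] < limits[modality]:
            counts[modality] += 1
            picked.append(bucket)
        else:
            # Cap reached: exclude every bucket of this modality.
            for k in keys:
                if modality_of(k) == modality:
                    weights[k] = 0.0
    return picked


buckets = {(256, 256, 1): 0.5, (720, 1280, 1): 0.3, (720, 1280, 16): 0.2}
picked = sample_items(buckets, {"image": 2, "video": 1}, 3, random.Random(0))
print(picked)
```

With caps of two images and one video and three requested items, every run ends with exactly two image buckets and one video bucket; only their order and sizes depend on the seed.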
|
||||||
|
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
num_requests: int,
|
||||||
|
request_id_prefix: str = "",
|
||||||
|
prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN,
|
||||||
|
range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO,
|
||||||
|
input_len: int = RandomDataset.DEFAULT_INPUT_LEN,
|
||||||
|
output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN,
|
||||||
|
limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT,
|
||||||
|
base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST,
|
||||||
|
num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
|
||||||
|
bucket_config: dict[tuple[int, int, int], float] =
|
||||||
|
DEFAULT_MM_ITEM_BUCKET_CONFIG,
|
||||||
|
enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT,
|
||||||
|
**kwargs,
|
||||||
|
) -> list[SampleRequest]:
|
||||||
|
|
||||||
|
# NOTE: Video sampling is WIP. Raise error if video is in bucket config
|
||||||
|
# and probability is non-zero.
|
||||||
|
if any(self.map_config_to_modality(cfg) == "video" and p > 0
|
||||||
|
for cfg, p in bucket_config.items()):
|
||||||
|
raise NotImplementedError("Video sampling not implemented; "
|
||||||
|
"set its probability to 0.")
|
||||||
|
|
||||||
|
# Get the sampling parameters for the dataset
|
||||||
|
input_lens, output_lens, offsets = self.get_sampling_params(
|
||||||
|
num_requests, range_ratio, input_len, output_len, tokenizer
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
min_num_mm_items,
|
||||||
|
max_num_mm_items,
|
||||||
|
limit_mm_per_prompt,
|
||||||
|
bucket_config,
|
||||||
|
) = self.get_mm_item_sampling_params(
|
||||||
|
base_items_per_request,
|
||||||
|
num_mm_items_range_ratio,
|
||||||
|
limit_mm_per_prompt,
|
||||||
|
bucket_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate prefix once
|
||||||
|
prefix_token_ids = self.get_prefix(tokenizer, prefix_len)
|
||||||
|
vocab_size = tokenizer.vocab_size
|
||||||
|
# Add synthetic multimodal items to each request
|
||||||
|
mm_requests = []
|
||||||
|
for i in range(num_requests):
|
||||||
|
prompt, total_input_len = self.generate_token_sequence(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
prefix_token_ids=prefix_token_ids,
|
||||||
|
prefix_len=prefix_len,
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
input_len=int(input_lens[i]),
|
||||||
|
offset=int(offsets[i]),
|
||||||
|
index=i,
|
||||||
|
)
|
||||||
|
# Get multimodal item iterator for a given request
|
||||||
|
mm_item_iterator = self.get_mm_item_iterator(
|
||||||
|
min_num_mm_items,
|
||||||
|
max_num_mm_items,
|
||||||
|
bucket_config,
|
||||||
|
limit_mm_per_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
mm_content = cast(list[dict[str, Any]], [
|
||||||
|
self.generate_mm_item(mm_item_config)
|
||||||
|
for mm_item_config in mm_item_iterator
|
||||||
|
])
|
||||||
|
|
||||||
|
if enable_multimodal_chat:
|
||||||
|
# NOTE: For now this option is only provided for completeness
|
||||||
|
# given that the serve.py benchmark currently does not use it.
|
||||||
|
mm_chat_prompt: Any = prompt
|
||||||
|
mm_chat_prompt = self.apply_multimodal_chat_transformation(
|
||||||
|
prompt, mm_content)
|
||||||
|
sample_request = SampleRequest(
|
||||||
|
prompt=mm_chat_prompt,
|
||||||
|
prompt_len=total_input_len,
|
||||||
|
expected_output_len=int(output_lens[i]),
|
||||||
|
multi_modal_data=None,
|
||||||
|
request_id=request_id_prefix + str(i),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample_request = SampleRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_len=total_input_len,
|
||||||
|
expected_output_len=int(output_lens[i]),
|
||||||
|
multi_modal_data=mm_content,
|
||||||
|
request_id=request_id_prefix + str(i),
|
||||||
|
)
|
||||||
|
mm_requests.append(sample_request)
|
||||||
|
return mm_requests
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# ShareGPT Dataset Implementation
|
# ShareGPT Dataset Implementation
|
||||||
@ -545,8 +1004,8 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
|||||||
type=str,
|
type=str,
|
||||||
default="random",
|
default="random",
|
||||||
choices=[
|
choices=[
|
||||||
"sharegpt", "burstgpt", "sonnet", "random", "hf", "custom",
|
"sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf",
|
||||||
"prefix_repetition"
|
"custom", "prefix_repetition"
|
||||||
],
|
],
|
||||||
help="Name of the dataset to benchmark on.",
|
help="Name of the dataset to benchmark on.",
|
||||||
)
|
)
|
||||||
@ -647,6 +1106,98 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
|
|||||||
"input_len * (1 + range_ratio)]."),
|
"input_len * (1 + range_ratio)]."),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# random multimodal dataset options
|
||||||
|
random_mm_group = parser.add_argument_group(
|
||||||
|
"random multimodal dataset options extended from random dataset")
|
||||||
|
random_mm_group.add_argument(
|
||||||
|
"--random-mm-base-items-per-request",
|
||||||
|
type=int,
|
||||||
|
default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST,
|
||||||
|
help=(
|
||||||
|
"Base number of multimodal items per request for random-mm. "
|
||||||
|
"Actual per-request count is sampled around this base using "
|
||||||
|
"--random-mm-num-mm-items-range-ratio."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
random_mm_group.add_argument(
|
||||||
|
"--random-mm-num-mm-items-range-ratio",
|
||||||
|
type=float,
|
||||||
|
default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO,
|
||||||
|
help=(
|
||||||
|
"Range ratio r in [0, 1] for sampling items per request. "
|
||||||
|
"We sample uniformly from the closed integer range "
|
||||||
|
"[floor(n*(1-r)), ceil(n*(1+r))] "
|
||||||
|
"where n is the base items per request. "
|
||||||
|
"r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped "
|
||||||
|
"to the sum of per-modality limits from "
|
||||||
|
"--random-mm-limit-mm-per-prompt. "
|
||||||
|
"An error is raised if the computed min exceeds the max."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
random_mm_group.add_argument(
|
||||||
|
"--random-mm-limit-mm-per-prompt",
|
||||||
|
type=json.loads,
|
||||||
|
default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT,
|
||||||
|
help=(
|
||||||
|
"Per-modality hard caps for items attached per request, e.g. "
|
||||||
|
"'{\"image\": 3, \"video\": 0}'. The sampled per-request item "
|
||||||
|
"count is clamped to the sum of these limits. When a modality "
|
||||||
|
"reaches its cap, its buckets are excluded and probabilities are "
|
||||||
|
"renormalized. "
|
||||||
|
"OBS.: Only image sampling is supported for now."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_mm_bucket_config(v: object) -> dict[tuple[int, int, int], float]:
|
||||||
|
# If already a dict (e.g., programmatic call), normalize keys
|
||||||
|
def normalize(d: dict) -> dict[tuple[int, int, int], float]:
|
||||||
|
out: dict[tuple[int, int, int], float] = {}
|
||||||
|
for k, val in d.items():
|
||||||
|
key = k
|
||||||
|
if isinstance(key, str):
|
||||||
|
with suppress(Exception):
|
||||||
|
key = ast.literal_eval(key)
|
||||||
|
if not (isinstance(key, tuple) and len(key) == 3
|
||||||
|
and all(isinstance(x, int) for x in key)):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid bucket key {k!r}. Expected tuple (H, W, T)."
|
||||||
|
)
|
||||||
|
out[(int(key[0]), int(key[1]), int(key[2]))] = float(val)
|
||||||
|
return out
|
||||||
|
|
||||||
|
if isinstance(v, dict):
|
||||||
|
return normalize(v)
|
||||||
|
if isinstance(v, str):
|
||||||
|
# Python literal (supports tuple keys)
|
||||||
|
parsed = ast.literal_eval(v)
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
raise ValueError("Bucket config must parse to a dict.")
|
||||||
|
return normalize(parsed)
|
||||||
|
raise ValueError("Unsupported value for --random-mm-bucket-config.")
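The parser above relies on `ast.literal_eval` rather than `json.loads`, presumably because bucket keys are tuples, which JSON cannot express as object keys. A short sketch of the difference, using only the standard library (the variable names are illustrative):

```python
import ast
import json

raw = "{(256, 256, 1): 0.5, (720, 1280, 1): 0.5}"

# Python literal syntax supports tuple keys directly.
parsed = ast.literal_eval(raw)
print(parsed[(256, 256, 1)])  # 0.5

# The same string is not valid JSON: object keys must be double-quoted strings.
try:
    json.loads(raw)
    json_ok = True
except json.JSONDecodeError:
    json_ok = False
print(json_ok)  # False

# A key that arrives as a string can be converted the same way.
key = ast.literal_eval("(720, 1280, 1)")
print(key in parsed)  # True
```

`ast.literal_eval` evaluates literals only (numbers, strings, tuples, dicts, and so on), so it does not execute arbitrary expressions passed on the command line.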
|
||||||
|
|
||||||
|
random_mm_group.add_argument(
|
||||||
|
"--random-mm-bucket-config",
|
||||||
|
type=_parse_mm_bucket_config,
|
||||||
|
default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG,
|
||||||
|
help=(
|
||||||
|
"The bucket config is a dictionary mapping a multimodal item "
|
||||||
|
"sampling configuration to a probability. "
|
||||||
|
"Currently allows for 2 modalities: images and videos. "
|
||||||
|
"A bucket key is a tuple of (height, width, num_frames). "
|
||||||
|
"The value is the probability of sampling that specific item. "
|
||||||
|
"Example: "
|
||||||
|
"--random-mm-bucket-config "
|
||||||
|
"{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10} "
|
||||||
|
"First item: images with resolution 256x256 w.p. 0.5 "
|
||||||
|
"Second item: images with resolution 720x1280 w.p. 0.4 "
|
||||||
|
"Third item: videos with resolution 720x1280 and 16 frames w.p. 0.1 "
|
||||||
|
"OBS.: If the probabilities do not sum to 1, they are normalized. "
|
||||||
|
"OBS bis.: Only image sampling is supported for now."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
hf_group = parser.add_argument_group("hf dataset options")
|
hf_group = parser.add_argument_group("hf dataset options")
|
||||||
hf_group.add_argument("--hf-subset",
|
hf_group.add_argument("--hf-subset",
|
||||||
type=str,
|
type=str,
|
||||||
@ -821,6 +1372,22 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
range_ratio=args.random_range_ratio,
|
range_ratio=args.random_range_ratio,
|
||||||
request_id_prefix=args.request_id_prefix,
|
request_id_prefix=args.request_id_prefix,
|
||||||
),
|
),
|
||||||
|
"random-mm":
|
||||||
|
lambda: RandomMultiModalDataset(
|
||||||
|
random_seed=args.seed, dataset_path=args.dataset_path
|
||||||
|
).sample(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
num_requests=args.num_prompts,
|
||||||
|
prefix_len=args.random_prefix_len,
|
||||||
|
range_ratio=args.random_range_ratio,
|
||||||
|
input_len=args.random_input_len,
|
||||||
|
output_len=args.random_output_len,
|
||||||
|
base_items_per_request=args.random_mm_base_items_per_request,
|
||||||
|
limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt,
|
||||||
|
num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio,
|
||||||
|
bucket_config=args.random_mm_bucket_config,
|
||||||
|
request_id_prefix=args.request_id_prefix,
|
||||||
|
),
|
||||||
"prefix_repetition":
|
"prefix_repetition":
|
||||||
lambda: PrefixRepetitionRandomDataset(
|
lambda: PrefixRepetitionRandomDataset(
|
||||||
random_seed=args.seed, dataset_path=args.dataset_path
|
random_seed=args.seed, dataset_path=args.dataset_path
|
||||||
@ -836,6 +1403,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Enforce endpoint compatibility for multimodal datasets.
|
||||||
|
if args.dataset_name == "random-mm" and args.endpoint_type not in [
|
||||||
|
"openai-chat"]:
|
||||||
|
raise ValueError(
|
||||||
|
"Multi-modal content (images) is only supported on "
|
||||||
|
"'openai-chat' backend."
|
||||||
|
)
|
||||||
input_requests = dataset_mapping[args.dataset_name]()
|
input_requests = dataset_mapping[args.dataset_name]()
|
||||||
except KeyError as err:
|
except KeyError as err:
|
||||||
raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
|
raise ValueError(f"Unknown dataset: {args.dataset_name}") from err
|
||||||
|
|||||||