Signed-off-by: Reagan Lee <reaganjlee@gmail.com>
Signed-off-by: Reagan <reaganjlee@gmail.com>
This commit is contained in:
Reagan 2025-11-20 13:06:52 -05:00
parent e087fbc393
commit e3dd9108cb
14 changed files with 858 additions and 81 deletions

View File

@ -0,0 +1,9 @@
# vllm bench multimodal-processor
## JSON CLI Arguments
--8<-- "docs/cli/json_tip.inc.md"
## Arguments
--8<-- "docs/argparse/bench_multimodal_processor.inc.md"

View File

@ -92,6 +92,7 @@ def auto_mock(module_name: str, attr: str, max_mocks: int = 100):
bench_latency = auto_mock("vllm.benchmarks", "latency")
bench_multimodal_processor = auto_mock("vllm.benchmarks", "multimodal_processor")
bench_serve = auto_mock("vllm.benchmarks", "serve")
bench_sweep_plot = auto_mock("vllm.benchmarks.sweep.plot", "SweepPlotArgs")
bench_sweep_plot_pareto = auto_mock(
@ -222,6 +223,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
"run-batch": create_parser(openai_run_batch.make_arg_parser),
# Benchmark CLI
"bench_latency": create_parser(bench_latency.add_cli_args),
"bench_multimodal_processor": create_parser(bench_multimodal_processor.add_cli_args),
"bench_serve": create_parser(bench_serve.add_cli_args),
"bench_sweep_plot": create_parser(bench_sweep_plot.add_cli_args),
"bench_sweep_plot_pareto": create_parser(bench_sweep_plot_pareto.add_cli_args),

View File

@ -12,7 +12,11 @@ import torch
import torch.nn as nn
from PIL import Image
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.config import (
ModelConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.config.multimodal import (
AudioDummyOptions,
BaseDummyOptions,

View File

@ -1908,7 +1908,8 @@ def get_samples(args, tokenizer: TokenizerLike) -> list[SampleRequest]:
try:
# Enforce endpoint compatibility for multimodal datasets.
if args.dataset_name == "random-mm" and args.backend not in ["openai-chat"]:
backend = getattr(args, "backend", "openai-chat")
if args.dataset_name == "random-mm" and backend not in ["openai-chat"]:
raise ValueError(
"Multi-modal content (images) is only supported on "
"'openai-chat' backend."

View File

@ -0,0 +1,464 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
r"""Benchmark multimodal processor latency.
This benchmark measures the latency of the multimodal processor module
using randomly generated multimodal prompts with synthetic images.
MM processor stats are automatically enabled.
Run:
vllm bench multimodal-processor \
--model <your_model> \
--num-prompts 10 \
--input-len 1024 \
--output-len 128 \
--num-images 1
"""
import argparse
import dataclasses
import json
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Any
import numpy as np
from vllm.engine.arg_utils import EngineArgs
from vllm.multimodal.processing import (
get_timing_stats_from_engine_client,
)
from vllm.utils.gc_utils import freeze_gc_heap
from vllm.utils.import_utils import PlaceholderModule
try:
import pandas as pd
except ImportError:
pd = PlaceholderModule("pandas")
@dataclass
class MultimodalProcessorBenchmarkMetrics:
    """Aggregated results of a multimodal processor benchmark run.

    All end-to-end latency values are in milliseconds.
    """

    # Number of requests that finished successfully.
    completed: int
    # Number of requests that did not finish.
    failed: int
    mean_e2el_ms: float
    median_e2el_ms: float
    std_e2el_ms: float
    # (percentile, value_ms) pairs for the selected percentiles.
    percentiles_e2el_ms: list[tuple[float, float]]
    # Per-stage timing stats: mean, median, std, percentiles for each stage.
    mm_processor_stats: dict[str, dict[str, float]]
def collect_mm_processor_stats(
    llm_engine: Any,
    debug: bool = False,
) -> dict[str, list[float]]:
    """
    Collect multimodal processor timing stats from the engine client.

    Returns a dictionary mapping stage names to lists of per-request timing
    values (in seconds); stages missing from a request's stats contribute 0.0.
    """
    stage_names = (
        "hf_processor_time",
        "hashing_time",
        "cache_lookup_time",
        "prompt_update_time",
        "total_time",
    )
    per_request_stats = get_timing_stats_from_engine_client(llm_engine)

    stats_by_stage: dict[str, list[float]] = {name: [] for name in stage_names}
    for stats_dict in per_request_stats.values():
        for name in stage_names:
            stats_by_stage[name].append(stats_dict.get(name, 0.0))

    if debug and not any(stats_by_stage.values()):
        print(
            "Warning: No MM processor stats found. Ensure --enable-mm-processor-stats is set."
        )
    return stats_by_stage
def calculate_mm_processor_metrics(
    stats_by_stage: dict[str, list[float]],
    selected_percentiles: list[float],
) -> dict[str, dict[str, float]]:
    """
    Calculate aggregate metrics from stats by stage.

    For each stage, returns mean/median/std plus the requested percentiles,
    all in milliseconds. Stages with no samples report 0.0 everywhere.
    """
    metrics: dict[str, dict[str, float]] = {}
    for stage_name, times in stats_by_stage.items():
        if times:
            # Convert seconds -> milliseconds once, then aggregate.
            samples_ms = np.asarray(times, dtype=float) * 1000
            stage_metrics = {
                "mean": float(np.mean(samples_ms)),
                "median": float(np.median(samples_ms)),
                "std": float(np.std(samples_ms)),
            }
            for p in selected_percentiles:
                stage_metrics[f"p{p}"] = float(np.percentile(samples_ms, p))
        else:
            stage_metrics = {"mean": 0.0, "median": 0.0, "std": 0.0}
            for p in selected_percentiles:
                stage_metrics[f"p{p}"] = 0.0
        metrics[stage_name] = stage_metrics
    return metrics
def generate_random_multimodal_prompts(
    num_prompts: int,
    input_len: int,
    output_len: int,
    tokenizer: Any,
    num_images: int = 1,
    image_width: int = 256,
    image_height: int = 256,
    seed: int = 0,
) -> tuple[list[list[dict]], list[int]]:
    """
    Generate random multimodal prompts with synthetic images and text tokens.

    Args:
        num_prompts: Number of prompts to generate.
        input_len: Number of random token ids per text prompt.
        output_len: Expected output length recorded for every prompt.
        tokenizer: Tokenizer used to decode random token ids into text.
        num_images: Number of synthetic images attached to each prompt.
        image_width: Width of generated images in pixels.
        image_height: Height of generated images in pixels.
        seed: Seed for the random generator (deterministic output).

    Returns:
        tuple: (prompts, expected_output_lens)
        - prompts: List of OpenAI chat format messages with text and images
        - expected_output_lens: List of expected output lengths
    """
    from PIL import Image

    from vllm.benchmarks.datasets import process_image

    rng = np.random.default_rng(seed)
    prompts = []
    expected_output_lens = []

    # Hoisted out of the loop: the vocab size does not change per prompt.
    vocab_size = tokenizer.vocab_size

    for _ in range(num_prompts):
        prompt_token_ids = rng.integers(0, vocab_size, size=input_len).tolist()
        text_prompt = tokenizer.decode(prompt_token_ids)

        mm_items = []
        for _ in range(num_images):
            # Generate random RGB image
            random_pixels = rng.integers(
                0, 256, (image_height, image_width, 3), dtype=np.uint8
            )
            image = Image.fromarray(random_pixels)
            # Process to OpenAI format
            mm_item = process_image(image)
            mm_items.append(mm_item)

        # Create chat format: text + images
        content = [{"type": "text", "text": text_prompt}]
        content.extend(mm_items)
        prompts.append([{"role": "user", "content": content}])
        expected_output_lens.append(output_len)

    return prompts, expected_output_lens
def benchmark_multimodal_processor(
    args: argparse.Namespace,
) -> dict[str, Any]:
    """
    Run the multimodal processor benchmark.

    Builds an LLM from the CLI engine args, generates random multimodal chat
    prompts, runs them through ``llm.chat``, and aggregates both end-to-end
    latency and per-stage multimodal processor timing stats.

    Args:
        args: Parsed CLI namespace (engine args plus benchmark flags).

    Returns:
        A dict with completed/failed counts, end-to-end latency stats (ms),
        and per-stage MM processor metrics.
    """
    from vllm import LLM, SamplingParams

    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))

    # Validate max_model_len
    # NOTE(review): `assert` is stripped under `python -O`; consider raising
    # ValueError instead — confirm intended behavior.
    assert llm.llm_engine.model_config.max_model_len >= (
        args.input_len + args.output_len
    ), (
        "Please ensure that max_model_len is greater than "
        "the sum of input_len and output_len."
    )

    # Generate random multimodal prompts
    seed = getattr(args, "seed", 0)
    tokenizer = llm.get_tokenizer()
    prompts, expected_output_lens = generate_random_multimodal_prompts(
        num_prompts=args.num_prompts,
        input_len=args.input_len,
        output_len=args.output_len,
        tokenizer=tokenizer,
        num_images=args.num_images,
        image_width=args.image_width,
        image_height=args.image_height,
        seed=seed,
    )

    # Create sampling params
    sampling_params = [
        SamplingParams(
            n=1,
            temperature=0.0,  # Greedy sampling for deterministic speed benchmarks
            max_tokens=output_len,
            detokenize=True,
        )
        for output_len in expected_output_lens
    ]

    selected_percentiles = [
        float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
    ]

    # Reduce GC jitter inside the timed region.
    freeze_gc_heap()

    # MM processor stats are automatically enabled via set_defaults
    # No need to check or raise error
    debug = getattr(args, "debug_mm_stats", False)

    print(f"Processing {len(prompts)} requests...")
    start_time = time.perf_counter()
    outputs = llm.chat(
        prompts, sampling_params, use_tqdm=not getattr(args, "disable_tqdm", False)
    )
    end_time = time.perf_counter()
    total_time = end_time - start_time

    # Pull per-stage MM processor timings recorded during llm.chat.
    mm_stats_by_stage = collect_mm_processor_stats(
        llm.llm_engine,
        debug=debug,
    )
    if not any(mm_stats_by_stage.values()):
        print(
            "\n⚠️ Warning: No MM processor stats found in registry.\n"
            " This may indicate that:\n"
            " - No multimodal requests were processed\n"
            " - Stats were already retrieved (registry is cleared after retrieval)\n"
        )
    mm_processor_metrics = calculate_mm_processor_metrics(
        mm_stats_by_stage, selected_percentiles
    )

    completed = len([o for o in outputs if o.finished])
    failed = len(outputs) - completed

    # Compute per-request end-to-end latency from request metrics when
    # available, preferring finished_time over last_token_time.
    e2el_times = []
    for output in outputs:
        if not output.finished or output.metrics is None:
            continue
        metrics = output.metrics
        for attr in ("finished_time", "last_token_time"):
            if (
                getattr(metrics, attr, None) is not None
                and getattr(metrics, "arrival_time", None) is not None
            ):
                e2el_times.append(
                    (getattr(metrics, attr) - metrics.arrival_time) * 1000
                )
                break

    # Fallback: approximate with the wall-clock average over completed
    # requests if no per-request metrics were recorded.
    if not e2el_times and completed > 0:
        avg_time_per_request = total_time / completed
        e2el_times = [avg_time_per_request * 1000] * completed

    if e2el_times:
        mean_e2el_ms = float(np.mean(e2el_times))
        median_e2el_ms = float(np.median(e2el_times))
        std_e2el_ms = float(np.std(e2el_times))
        percentiles_e2el_ms = [
            (p, float(np.percentile(e2el_times, p))) for p in selected_percentiles
        ]
    else:
        mean_e2el_ms = 0.0
        median_e2el_ms = 0.0
        std_e2el_ms = 0.0
        percentiles_e2el_ms = [(p, 0.0) for p in selected_percentiles]

    benchmark_result = {
        "completed": completed,
        "failed": failed,
        "mean_e2el_ms": mean_e2el_ms,
        "median_e2el_ms": median_e2el_ms,
        "std_e2el_ms": std_e2el_ms,
        "percentiles_e2el_ms": percentiles_e2el_ms,
        "mm_processor_stats": mm_processor_metrics,
    }
    return benchmark_result
def add_cli_args(parser: argparse.ArgumentParser) -> None:
    """Add CLI arguments for the multimodal processor benchmark.

    Registers the engine arguments plus the benchmark-specific flags, and
    forces ``enable_mm_processor_stats`` on since the benchmark depends on it.
    """
    # Add EngineArgs (no conflict since we removed dataset parser).
    # EngineArgs is already imported at module level; the previous
    # function-local re-import was redundant and has been removed.
    EngineArgs.add_cli_args(parser)

    # Automatically enable MM processor stats (required for this benchmark)
    parser.set_defaults(enable_mm_processor_stats=True)

    # Random generation arguments (similar to latency.py)
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=10,
        help="Number of prompts to process.",
    )
    parser.add_argument(
        "--input-len",
        type=int,
        default=1024,
        help="Number of input tokens per request.",
    )
    parser.add_argument(
        "--output-len",
        type=int,
        default=128,
        help="Number of output tokens per request.",
    )
    parser.add_argument(
        "--num-images",
        type=int,
        default=1,
        help="Number of images per prompt.",
    )
    parser.add_argument(
        "--image-width",
        type=int,
        default=256,
        help="Width of generated images in pixels.",
    )
    parser.add_argument(
        "--image-height",
        type=int,
        default=256,
        help="Height of generated images in pixels.",
    )
    parser.add_argument(
        "--output-json",
        type=str,
        default=None,
        help="Path to save the benchmark results in JSON format.",
    )
    parser.add_argument(
        "--debug-mm-stats",
        action="store_true",
        help="Enable debug logging for MM processor stats collection.",
    )
    parser.add_argument(
        "--metric-percentiles",
        type=str,
        default="99",
        help="Comma-separated list of percentiles to calculate (e.g., '50,90,99').",
    )
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Disable tqdm progress bar.",
    )
def main(args: argparse.Namespace) -> None:
    """Main entry point for the multimodal processor benchmark.

    Runs the benchmark, prints per-stage and end-to-end latency tables, and
    optionally saves the results (plus the run config) as JSON.
    """
    # NOTE: the previous function-local `from datetime import datetime` was
    # redundant — `datetime` is already imported at module level.
    print("Starting multimodal processor benchmark...")
    result = benchmark_multimodal_processor(args)

    print("\n" + "=" * 80)
    print("Multimodal Processor Benchmark Results")
    print("=" * 80)

    # Parse the percentile list once; both report sections below use it.
    selected_percentiles = [
        float(p) for p in getattr(args, "metric_percentiles", "99").split(",")
    ]

    if "mm_processor_stats" in result:
        print("\nMM Processor Timing (ms):")
        mm_data = []
        for stage, metrics in result["mm_processor_stats"].items():
            row = {
                "Stage": stage,
                "Mean": f"{metrics['mean']:.2f}",
                "Median": f"{metrics['median']:.2f}",
                "Std": f"{metrics['std']:.2f}",
            }
            for p in selected_percentiles:
                row[f"P{p}"] = f"{metrics.get(f'p{p}', 0.0):.2f}"
            mm_data.append(row)
        mm_df = pd.DataFrame(mm_data)
        print(mm_df.to_string(index=False))

    if "mean_e2el_ms" in result:
        print("\nEnd-to-End Latency (ms):")
        e2el_data = [
            {"Metric": "Mean", "Value (ms)": f"{result['mean_e2el_ms']:.2f}"},
            {"Metric": "Median", "Value (ms)": f"{result['median_e2el_ms']:.2f}"},
            {"Metric": "Std", "Value (ms)": f"{result['std_e2el_ms']:.2f}"},
        ]
        for p in selected_percentiles:
            percentile_value = next(
                (val for pct, val in result["percentiles_e2el_ms"] if pct == p),
                0.0,
            )
            e2el_data.append(
                {
                    "Metric": f"P{p}",
                    "Value (ms)": f"{percentile_value:.2f}",
                }
            )
        e2el_df = pd.DataFrame(e2el_data)
        print(e2el_df.to_string(index=False))

    if args.output_json:
        # Record the run configuration alongside the metrics for later
        # comparison between benchmark runs.
        result["config"] = {
            "model": args.model,
            "num_prompts": args.num_prompts,
            "input_len": args.input_len,
            "output_len": args.output_len,
            "num_images": args.num_images,
            "image_width": args.image_width,
            "image_height": args.image_height,
        }
        result["timestamp"] = datetime.now().isoformat()
        with open(args.output_json, "w") as f:
            json.dump(result, f, indent=2)
        print(f"\nResults saved to {args.output_json}")
if __name__ == "__main__":
    # Allow running this benchmark directly as a script, outside `vllm bench`.
    parser = argparse.ArgumentParser(
        description="Benchmark multimodal processor latency"
    )
    add_cli_args(parser)
    args = parser.parse_args()
    main(args)

View File

@ -64,6 +64,11 @@ class ObservabilityConfig:
module in the model and attach informations such as input/output shapes to
nvtx range markers. Noted that this doesn't work with CUDA graphs enabled."""
enable_mm_processor_stats: bool = False
"""Enable collection of timing statistics for multimodal processor operations.
This can be useful for performance analysis and debugging. Defaults to `False`
(disabled)."""
@cached_property
def collect_model_forward_time(self) -> bool:
"""Whether to collect model forward time for the request."""

View File

@ -521,6 +521,7 @@ class EngineArgs:
enable_layerwise_nvtx_tracing: bool = (
ObservabilityConfig.enable_layerwise_nvtx_tracing
)
enable_mm_processor_stats: bool = ObservabilityConfig.enable_mm_processor_stats
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
@ -1040,6 +1041,10 @@ class EngineArgs:
"--enable-layerwise-nvtx-tracing",
**observability_kwargs["enable_layerwise_nvtx_tracing"],
)
observability_group.add_argument(
"--enable-mm-processor-stats",
**observability_kwargs["enable_mm_processor_stats"],
)
# Scheduler arguments
scheduler_kwargs = get_kwargs(SchedulerConfig)
@ -1682,6 +1687,7 @@ class EngineArgs:
kv_cache_metrics_sample=self.kv_cache_metrics_sample,
cudagraph_metrics=self.cudagraph_metrics,
enable_layerwise_nvtx_tracing=self.enable_layerwise_nvtx_tracing,
enable_mm_processor_stats=self.enable_mm_processor_stats,
)
# Compilation config overrides

View File

@ -1,6 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand
from vllm.entrypoints.cli.benchmark.multimodal_processor import (
BenchmarkMultimodalProcessorSubcommand,
)
from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand
from vllm.entrypoints.cli.benchmark.startup import BenchmarkStartupSubcommand
from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand
@ -8,6 +11,7 @@ from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcomm
__all__: list[str] = [
"BenchmarkLatencySubcommand",
"BenchmarkMultimodalProcessorSubcommand",
"BenchmarkServingSubcommand",
"BenchmarkStartupSubcommand",
"BenchmarkSweepSubcommand",

View File

@ -0,0 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from vllm.benchmarks.multimodal_processor import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
class BenchmarkMultimodalProcessorSubcommand(BenchmarkSubcommandBase):
    """The `multimodal-processor` subcommand for `vllm bench`."""

    name = "multimodal-processor"
    help = "Benchmark multimodal processor latency across different configurations."

    @classmethod
    def add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
        # Delegate to the benchmark module so CLI flags live in one place.
        add_cli_args(parser)

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        # Run the benchmark with the parsed arguments.
        main(args)

View File

@ -651,9 +651,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"NO_COLOR": lambda: os.getenv("NO_COLOR", "0") != "0",
# If set, vllm will log stats at this interval in seconds
# If not set, vllm will log stats every 10 seconds.
"VLLM_LOG_STATS_INTERVAL": lambda: val
if (val := float(os.getenv("VLLM_LOG_STATS_INTERVAL", "10."))) > 0.0
else 10.0,
"VLLM_LOG_STATS_INTERVAL": lambda: (
val
if (val := float(os.getenv("VLLM_LOG_STATS_INTERVAL", "10."))) > 0.0
else 10.0
),
# Trace function calls
# If set to 1, vllm will trace function calls
# Useful for debugging
@ -678,28 +680,30 @@ environment_variables: dict[str, Callable[[], Any]] = {
),
),
# If set, vllm will use flashinfer sampler
"VLLM_USE_FLASHINFER_SAMPLER": lambda: bool(
int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"])
)
if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ
else None,
"VLLM_USE_FLASHINFER_SAMPLER": lambda: (
bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ
else None
),
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
# (CPU backend only) CPU key-value cache space.
# default is None and will be set as 4 GB
"VLLM_CPU_KVCACHE_SPACE": lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0"))
if "VLLM_CPU_KVCACHE_SPACE" in os.environ
else None,
"VLLM_CPU_KVCACHE_SPACE": lambda: (
int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0"))
if "VLLM_CPU_KVCACHE_SPACE" in os.environ
else None
),
# (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31",
# "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'.
"VLLM_CPU_OMP_THREADS_BIND": lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "auto"),
# (CPU backend only) CPU cores not used by OMP threads .
# Those CPU cores will not be used by OMP threads of a rank.
"VLLM_CPU_NUM_OF_RESERVED_CPU": lambda: int(
os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0")
)
if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ
else None,
"VLLM_CPU_NUM_OF_RESERVED_CPU": lambda: (
int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0"))
if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ
else None
),
# (CPU backend only) whether to use SGL kernels, optimized for small batch.
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))),
# If the env var is set, Ray Compiled Graph uses the specified
@ -843,9 +847,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded
# if this is set to an empty string, no plugins will be loaded
"VLLM_PLUGINS": lambda: None
if "VLLM_PLUGINS" not in os.environ
else os.environ["VLLM_PLUGINS"].split(","),
"VLLM_PLUGINS": lambda: (
None
if "VLLM_PLUGINS" not in os.environ
else os.environ["VLLM_PLUGINS"].split(",")
),
# a local directory to look in for unrecognized LoRA adapters.
# only works if plugins are enabled and
# VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled.
@ -917,9 +923,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# and performance comparisons. Currently only affects MPLinearKernel
# selection
# (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel)
"VLLM_DISABLED_KERNELS": lambda: []
if "VLLM_DISABLED_KERNELS" not in os.environ
else os.environ["VLLM_DISABLED_KERNELS"].split(","),
"VLLM_DISABLED_KERNELS": lambda: (
[]
if "VLLM_DISABLED_KERNELS" not in os.environ
else os.environ["VLLM_DISABLED_KERNELS"].split(",")
),
# Disable pynccl (using torch.distributed instead)
"VLLM_DISABLE_PYNCCL": lambda: (
os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")
@ -1155,11 +1163,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
== "1",
# Gap between padding buckets for the forward pass. So we have
# 8, we will run forward pass with [16, 24, 32, ...].
"VLLM_TPU_BUCKET_PADDING_GAP": lambda: int(
os.environ["VLLM_TPU_BUCKET_PADDING_GAP"]
)
if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ
else 0,
"VLLM_TPU_BUCKET_PADDING_GAP": lambda: (
int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ
else 0
),
"VLLM_TPU_MOST_MODEL_LEN": lambda: maybe_convert_int(
os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None)
),

View File

@ -6,7 +6,7 @@ from typing import Any, cast
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.config import ModelConfig, ObservabilityConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
@ -47,6 +47,7 @@ class InputPreprocessor:
self,
model_config: ModelConfig,
tokenizer: TokenizerLike | None,
observability_config: ObservabilityConfig | None = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: BaseMultiModalProcessorCache | None = None,
) -> None:
@ -54,6 +55,7 @@ class InputPreprocessor:
self.model_config = model_config
self.tokenizer = tokenizer
self.observability_config = observability_config
self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
@ -232,6 +234,7 @@ class InputPreprocessor:
if not hasattr(self, "_mm_processor"):
self._mm_processor = self.mm_registry.create_processor(
self.model_config,
self.observability_config,
tokenizer=self.tokenizer,
cache=self.mm_processor_cache,
)

View File

@ -1,9 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextvars
import threading
import time
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, Sequence
from contextlib import contextmanager
from dataclasses import dataclass, field, replace
from enum import Enum
from functools import lru_cache
@ -22,6 +25,7 @@ import regex as re
import torch
from typing_extensions import TypeVar, assert_never
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_processor_from_config
@ -53,7 +57,7 @@ if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from vllm.config import ModelConfig
from vllm.config import ModelConfig, ObservabilityConfig
from .cache import BaseMultiModalProcessorCache
from .profiling import BaseDummyInputsBuilder
@ -63,6 +67,7 @@ else:
ProcessorMixin = object
ModelConfig = object
ObservabilityConfig = object
BaseMultiModalProcessorCache = object
@ -70,6 +75,125 @@ logger = init_logger(__name__)
_S = TypeVar("_S", str, list[int])
# Context variable to store the current request_id during preprocessing
_request_id_context: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"_request_id_context", default=None
)
def get_current_request_id() -> str | None:
"""Get the current request_id from the context, if available."""
return _request_id_context.get()
@contextmanager
def set_request_id(request_id: str) -> Generator[None, None, None]:
"""Context manager to set the request_id for the current context."""
token = _request_id_context.set(request_id)
try:
yield
finally:
_request_id_context.reset(token)
@dataclass
class MultiModalProcessorTimingStats:
    """Per-request timing statistics for multimodal processor stages."""

    hf_processor_time: float = 0.0
    """Time spent in HuggingFace processor calls (seconds)."""

    hashing_time: float = 0.0
    """Time spent computing multimodal item hashes (seconds)."""

    cache_lookup_time: float = 0.0
    """Time spent in cache lookups and merges (seconds)."""

    prompt_update_time: float = 0.0
    """Time spent applying prompt updates and finding placeholders (seconds)."""

    total_time: float = 0.0
    """Total processing time (seconds)."""

    def to_dict(self) -> dict[str, float]:
        """Convert stats to a dictionary for JSON serialization."""
        return {
            name: getattr(self, name)
            for name in (
                "hf_processor_time",
                "hashing_time",
                "cache_lookup_time",
                "prompt_update_time",
                "total_time",
            )
        }
def get_timing_stats_from_engine_client(engine_client: Any) -> dict[str, dict[str, float]]:
    """
    Get all timing stats from the context associated with the engine client.

    Args:
        engine_client: The engine client that has input_processor.

    Returns:
        A dictionary mapping request_id to stats dict. Empty when stats
        collection is disabled or the processor chain is not reachable.
    """
    try:
        enabled = (
            engine_client.vllm_config.observability_config.enable_mm_processor_stats
        )
    except (AttributeError, RuntimeError):
        return {}
    if not enabled:
        return {}

    try:
        preprocessor = engine_client.input_processor.input_preprocessor
        getter = getattr(preprocessor, "_get_mm_processor", None)
        if getter is not None:
            mm_processor = getter()
            if mm_processor is not None and hasattr(mm_processor, "info"):
                return mm_processor.info.ctx.get_all_timing_stats()
    except (AttributeError, RuntimeError):
        # Best effort: any break in the attribute chain means no stats.
        pass
    return {}
@contextmanager
def _timed_operation(ctx: "InputProcessingContext", stage_name: str):
    """
    Context manager to time an operation using the context's timing stats.

    The request_id is automatically retrieved from the context variable,
    so it doesn't need to be passed as a parameter. When no context, no
    request_id, or no stats object is available, the body runs untimed.

    Args:
        ctx: The InputProcessingContext containing the timing stats registry.
        stage_name: Name of the stage being timed.
    """
    request_id = get_current_request_id()
    if ctx is None or request_id is None:
        yield
        return

    stats = ctx.get_timing_stats(request_id)
    if stats is None:
        yield
        return

    started = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - started
        # Map the stage label to the stats attribute it accumulates into;
        # unknown labels are silently ignored (matching prior behavior).
        attr = {
            "hf_processor": "hf_processor_time",
            "hashing": "hashing_time",
            "cache_lookup": "cache_lookup_time",
            "prompt_update": "prompt_update_time",
        }.get(stage_name)
        if attr is not None:
            setattr(stats, attr, getattr(stats, attr) + elapsed)
PromptSeq: TypeAlias = str | list[int]
"""A token sequence (list of token IDs) or text."""
@ -951,6 +1075,21 @@ class InputProcessingContext:
tokenizer: TokenizerLike | None
"""The tokenizer used to tokenize the inputs."""
observability_config: "ObservabilityConfig | None" = field(
default=None, compare=False, repr=False
)
"""Configuration for observability features."""
timing_stats_registry: dict[str, MultiModalProcessorTimingStats] = field(
default_factory=dict, compare=False, repr=False
)
"""Registry for storing timing stats keyed by request_id."""
_timing_stats_registry_lock: threading.Lock = field(
default_factory=threading.Lock, compare=False, repr=False
)
"""Lock for thread-safe access to timing_stats_registry."""
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
@ -1159,6 +1298,71 @@ class InputProcessingContext:
return self._postprocess_output(output)
def get_timing_stats(
    self, request_id: str
) -> MultiModalProcessorTimingStats | None:
    """
    Look up the timing stats recorded for *request_id*.

    Returns None when stats collection is disabled or the id is unknown.
    """
    cfg = self.observability_config
    if cfg is None or not cfg.enable_mm_processor_stats:
        return None
    with self._timing_stats_registry_lock:
        return self.timing_stats_registry.get(request_id)
def create_timing_stats(self, request_id: str) -> MultiModalProcessorTimingStats:
    """
    Create and register a fresh stats object for *request_id*.

    This should be called at the start of processing for a request. When
    stats collection is disabled, a detached throwaway object is returned
    and the registry is left untouched.

    Raises:
        ValueError: If stats already exist for *request_id*.
    """
    cfg = self.observability_config
    if cfg is None or not cfg.enable_mm_processor_stats:
        return MultiModalProcessorTimingStats()
    with self._timing_stats_registry_lock:
        if request_id in self.timing_stats_registry:
            raise ValueError(
                f"Timing stats already exist for request_id: {request_id}"
            )
        stats = MultiModalProcessorTimingStats()
        self.timing_stats_registry[request_id] = stats
        return stats
def clear_timing_stats_registry(self) -> int:
    """
    Drop every entry from the registry; return how many were removed.

    Returns 0 without touching the registry when stats are disabled.
    """
    cfg = self.observability_config
    if cfg is None or not cfg.enable_mm_processor_stats:
        return 0
    with self._timing_stats_registry_lock:
        removed = len(self.timing_stats_registry)
        self.timing_stats_registry.clear()
    return removed
def get_all_timing_stats(self) -> dict[str, dict[str, float]]:
    """
    Snapshot the registry as plain dicts keyed by request_id.

    Returns an empty mapping when stats collection is disabled.
    """
    cfg = self.observability_config
    if cfg is None or not cfg.enable_mm_processor_stats:
        return {}
    with self._timing_stats_registry_lock:
        return {
            rid: stats.to_dict()
            for rid, stats in self.timing_stats_registry.items()
        }
class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing."""
@ -1494,11 +1698,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Call the HF processor on the prompt text and
associated multi-modal data.
"""
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
with _timed_operation(self.info.ctx, "hf_processor"):
return self.info.ctx.call_hf_processor(
self.info.get_hf_processor(**mm_kwargs),
dict(text=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
def _hf_processor_applies_updates(
self,
@ -1846,12 +2051,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
)
# Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = self._hash_mm_items(
mm_data_items,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_uuids=mm_uuids,
)
with _timed_operation(self.info.ctx, "hashing"):
mm_hashes = self._hash_mm_items(
mm_data_items,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_uuids=mm_uuids,
)
mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items,
@ -1892,18 +2098,20 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_uuids=mm_uuids,
)
mm_hashes = self._hash_mm_items(
mm_data_items,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_uuids=mm_uuids,
)
with _timed_operation(self.info.ctx, "hashing"):
mm_hashes = self._hash_mm_items(
mm_data_items,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_uuids=mm_uuids,
)
mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
cache=cache,
mm_data_items=mm_data_items,
mm_hashes=mm_hashes,
)
with _timed_operation(self.info.ctx, "cache_lookup"):
mm_is_cached, mm_missing_data_items = self._get_cache_missing_items(
cache=cache,
mm_data_items=mm_data_items,
mm_hashes=mm_hashes,
)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt updates until the new multimodal
@ -1933,13 +2141,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_kwargs,
)
mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
cache,
mm_hashes=mm_hashes,
mm_is_cached=mm_is_cached,
mm_missing_kwargs=mm_missing_kwargs,
mm_missing_prompt_updates=mm_missing_prompt_updates,
)
with _timed_operation(self.info.ctx, "cache_lookup"):
mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
cache,
mm_hashes=mm_hashes,
mm_is_cached=mm_is_cached,
mm_missing_kwargs=mm_missing_kwargs,
mm_missing_prompt_updates=mm_missing_prompt_updates,
)
mm_info = MultiModalProcessingInfo(
kwargs=mm_kwargs,
@ -2121,6 +2330,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
3. Extract information about the placeholder tokens from the
processed token IDs.
"""
request_id = get_current_request_id()
if request_id is not None:
self.info.ctx.create_timing_stats(request_id)
stats = (
self.info.ctx.get_timing_stats(request_id)
if request_id is not None
else None
)
mm_items = self._to_mm_items(mm_data)
if tokenization_kwargs is None:
@ -2139,13 +2357,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
)
# NOTE: tokenization_kwargs are not required to init processor
prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
mm_items=mm_items,
prompt_ids=prompt_ids,
mm_kwargs=mm_info.kwargs,
mm_prompt_updates=mm_info.prompt_updates,
is_update_applied=is_update_applied,
)
with _timed_operation(self.info.ctx, "prompt_update"):
prompt_ids, mm_placeholders = self._maybe_apply_prompt_updates(
mm_items=mm_items,
prompt_ids=prompt_ids,
mm_kwargs=mm_info.kwargs,
mm_prompt_updates=mm_info.prompt_updates,
is_update_applied=is_update_applied,
)
mm_placeholder_ranges = {
modality: [item.to_range() for item in placeholders]

View File

@ -5,6 +5,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.observability import ObservabilityConfig
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
@ -22,7 +23,7 @@ from .profiling import (
)
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.config import ModelConfig, ObservabilityConfig
from vllm.model_executor.models.interfaces import SupportsMultiModal
logger = init_logger(__name__)
@ -148,6 +149,7 @@ class MultiModalRegistry:
*,
cache: BaseMultiModalProcessorCache | None = None,
profiler_limits: Mapping[str, int] | None = None,
observability_config: ObservabilityConfig | None = None,
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
@ -156,7 +158,11 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
return {}
processor = self.create_processor(model_config, cache=cache)
if observability_config is None:
observability_config = ObservabilityConfig()
processor = self.create_processor(
model_config, observability_config, cache=cache
)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len
@ -174,6 +180,7 @@ class MultiModalRegistry:
model_config: "ModelConfig",
*,
cache: BaseMultiModalProcessorCache | None = None,
observability_config: ObservabilityConfig | None = None,
) -> Mapping[str, int]:
"""
Get the maximum number of multi-modal input instances for each modality
@ -182,7 +189,11 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
return {}
processor = self.create_processor(model_config, cache=cache)
if observability_config is None:
observability_config = ObservabilityConfig()
processor = self.create_processor(
model_config, observability_config, cache=cache
)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
return profiler.get_mm_limits()
@ -231,27 +242,32 @@ class MultiModalRegistry:
def _create_processing_ctx(
self,
model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
tokenizer: TokenizerLike | None = None,
) -> InputProcessingContext:
if tokenizer is None and not model_config.skip_tokenizer_init:
tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext(model_config, tokenizer)
return InputProcessingContext(
model_config, tokenizer, observability_config=observability_config
)
def _create_processing_info(
self,
model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
*,
tokenizer: TokenizerLike | None = None,
) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer)
ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
return factories.info(ctx)
def create_processor(
self,
model_config: "ModelConfig",
observability_config: "ObservabilityConfig | None" = None,
*,
tokenizer: TokenizerLike | None = None,
cache: BaseMultiModalProcessorCache | None = None,
@ -265,7 +281,7 @@ class MultiModalRegistry:
model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory
ctx = self._create_processing_ctx(model_config, tokenizer)
ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
return factories.build_processor(ctx, cache=cache)
@ -276,13 +292,18 @@ class MultiModalRegistry:
mm_counts: Mapping[str, int] | None = None,
*,
cache: BaseMultiModalProcessorCache | None = None,
observability_config: ObservabilityConfig | None = None,
) -> DummyDecoderData:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`.
"""
processor = self.create_processor(model_config, cache=cache)
if observability_config is None:
observability_config = ObservabilityConfig()
processor = self.create_processor(
model_config, observability_config, cache=cache
)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
# Extract configurable options from multimodal config.
@ -309,13 +330,18 @@ class MultiModalRegistry:
mm_counts: Mapping[str, int] | None = None,
*,
cache: BaseMultiModalProcessorCache | None = None,
observability_config: ObservabilityConfig | None = None,
) -> DummyEncoderData:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by `model_config`.
"""
processor = self.create_processor(model_config, cache=cache)
if observability_config is None:
observability_config = ObservabilityConfig()
processor = self.create_processor(
model_config, observability_config, cache=cache
)
profiler: MultiModalProfiler = MultiModalProfiler(processor)
# Extract configurable options from multimodal config.

View File

@ -15,7 +15,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.processing import EncDecMultiModalProcessor, set_request_id
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
@ -57,6 +57,7 @@ class InputProcessor:
self.input_preprocessor = InputPreprocessor(
self.model_config,
tokenizer,
self.vllm_config.observability_config,
mm_registry,
mm_processor_cache=self.mm_processor_cache,
)
@ -445,11 +446,13 @@ class InputProcessor:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly.
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
with set_request_id(request_id):
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
mm_uuids=mm_uuids,
)
from vllm.platforms import current_platform
current_platform.validate_request(
@ -590,6 +593,7 @@ class InputProcessor:
mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor(
model_config,
self.vllm_config.observability_config,
tokenizer=tokenizer,
)
assert isinstance(mm_processor, EncDecMultiModalProcessor)