Add logging for cudagraph-related info (#29825)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Yong Hoon Shin 2025-12-02 23:01:48 -10:00 committed by GitHub
parent 3a7751485b
commit 69520bc695
9 changed files with 161 additions and 6 deletions

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from collections import Counter
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any
@@ -22,6 +23,99 @@ from vllm.utils.torch_utils import weak_ref_tensors
logger = init_logger(__name__)
@dataclasses.dataclass(frozen=True)
class CUDAGraphStat:
    num_unpadded_tokens: int
    num_padded_tokens: int
    num_paddings: int
    runtime_mode: str


class CUDAGraphLogging:
    """Aggregate and log cudagraph metrics"""

    COLUMN_HEADERS = [
        "Unpadded Tokens",
        "Padded Tokens",
        "Num Paddings",
        "Runtime Mode",
        "Count",
    ]

    def __init__(self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None):
        self.reset()
        self.cg_mode = str(cg_mode)
        self.cg_capture_sizes = str(cg_capture_sizes or [])
        self.settings_header = (
            "**CUDAGraph Config Settings:**\n\n"
            f"- Mode: {self.cg_mode}\n"
            f"- Capture sizes: {self.cg_capture_sizes}\n\n"
            "**CUDAGraph Stats:**\n\n"
        )

    def reset(self):
        self.stats = []

    def observe(self, cudagraph_stat: CUDAGraphStat):
        self.stats.append(cudagraph_stat)

    def generate_metric_table(self) -> str:
        stats_counts = Counter(self.stats)

        # Convert stats to rows of strings, in descending order of observed frequencies
        rows = []
        for stat, count in sorted(
            stats_counts.items(), key=lambda item: item[1], reverse=True
        ):
            rows.append(
                [
                    str(stat.num_unpadded_tokens),
                    str(stat.num_padded_tokens),
                    str(stat.num_paddings),
                    stat.runtime_mode,
                    str(count),
                ]
            )

        # Calculate column widths (max of header and data)
        col_widths = []
        for i, header_text in enumerate(self.COLUMN_HEADERS):
            max_width = len(header_text)
            for row in rows:
                max_width = max(max_width, len(row[i]))
            col_widths.append(max_width)

        table_header_list = [
            h.ljust(w) for h, w in zip(self.COLUMN_HEADERS, col_widths)
        ]
        table_header = "| " + " | ".join(table_header_list) + " |\n"
        table_separator = "|" + "|".join("-" * (w + 2) for w in col_widths) + "|\n"

        # Create data rows with proper alignment
        data_rows = []
        for row in rows:
            formatted_row = [
                str(val).ljust(width) for val, width in zip(row, col_widths)
            ]
            data_rows.append("| " + " | ".join(formatted_row) + " |")

        return (
            self.settings_header
            + table_header
            + table_separator
            + "\n".join(data_rows)
            + "\n"
        )

    def log(self, log_fn=logger.info):
        if not self.stats:
            return
        log_fn(self.generate_metric_table())
        self.reset()


@dataclasses.dataclass
class CUDAGraphEntry:
    batch_descriptor: BatchDescriptor
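
For illustration only (not part of this diff): a minimal sketch of how the two new classes above fit together. Identical observations are deduplicated with Counter and logged as a markdown-style table, most frequent row first. CUDAGraphMode.FULL, the capture sizes, and the stat values are arbitrary example inputs, and the import of CUDAGraphMode assumes the usual vllm.config re-export.

from vllm.compilation.cuda_graph import CUDAGraphLogging, CUDAGraphStat
from vllm.config import CUDAGraphMode

cg_logging = CUDAGraphLogging(CUDAGraphMode.FULL, [1, 2, 4, 8])
for _ in range(2):
    # Two identical observations collapse into a single table row with Count=2.
    cg_logging.observe(
        CUDAGraphStat(
            num_unpadded_tokens=5,
            num_padded_tokens=8,
            num_paddings=3,
            runtime_mode="FULL",
        )
    )
cg_logging.observe(
    CUDAGraphStat(
        num_unpadded_tokens=13,
        num_padded_tokens=16,
        num_paddings=3,
        runtime_mode="FULL",
    )
)
# Prints the settings header followed by the stats table, then resets the buffer.
cg_logging.log(log_fn=print)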

View File

@@ -55,6 +55,10 @@ class ObservabilityConfig:
kv_cache_metrics_sample: float = Field(default=0.01, gt=0, le=1)
"""Sampling rate for KV cache metrics (0.0, 1.0]. Default 0.01 = 1% of blocks."""
cudagraph_metrics: bool = False
"""Enable CUDA graph metrics (number of padded/unpadded tokens, runtime cudagraph
dispatch modes, and their observed frequencies at every logging interval)."""
@cached_property
def collect_model_forward_time(self) -> bool:
"""Whether to collect model forward time for the request."""

View File

@@ -518,6 +518,7 @@ class EngineArgs:
kv_cache_metrics_sample: float = get_field(
ObservabilityConfig, "kv_cache_metrics_sample"
)
cudagraph_metrics: bool = ObservabilityConfig.cudagraph_metrics
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: str | type[object] | None = SchedulerConfig.scheduler_cls
@@ -1021,6 +1022,10 @@ class EngineArgs:
"--kv-cache-metrics-sample",
**observability_kwargs["kv_cache_metrics_sample"],
)
observability_group.add_argument(
"--cudagraph-metrics",
**observability_kwargs["cudagraph_metrics"],
)
# Scheduler arguments
scheduler_kwargs = get_kwargs(SchedulerConfig)
@@ -1698,6 +1703,7 @@ class EngineArgs:
collect_detailed_traces=self.collect_detailed_traces,
kv_cache_metrics=self.kv_cache_metrics,
kv_cache_metrics_sample=self.kv_cache_metrics_sample,
cudagraph_metrics=self.cudagraph_metrics,
)
# Compilation config overrides

View File

@@ -7,6 +7,7 @@ from collections.abc import Iterable
from typing import Any
from vllm import envs
from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorMetadata,
@@ -1037,6 +1038,7 @@ class Scheduler(SchedulerInterface):
pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
kv_connector_output = model_runner_output.kv_connector_output
cudagraph_stats = model_runner_output.cudagraph_stats
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: SpecDecodingStats | None = None
@@ -1219,7 +1221,9 @@
finished_req_ids.clear()
if (
stats := self.make_stats(spec_decoding_stats, kv_connector_stats)
stats := self.make_stats(
spec_decoding_stats, kv_connector_stats, cudagraph_stats
)
) is not None:
# Return stats to only one of the front-ends.
if (eco := next(iter(engine_core_outputs.values()), None)) is None:
@@ -1420,6 +1424,7 @@
self,
spec_decoding_stats: SpecDecodingStats | None = None,
kv_connector_stats: KVConnectorStats | None = None,
cudagraph_stats: CUDAGraphStat | None = None,
) -> SchedulerStats | None:
if not self.log_stats:
return None
@@ -1444,6 +1449,7 @@
kv_cache_eviction_events=eviction_events,
spec_decoding_stats=spec_stats,
kv_connector_stats=connector_stats_payload,
cudagraph_stats=cudagraph_stats,
)
def make_spec_decoding_stats(

View File

@@ -10,6 +10,7 @@ from typing import TypeAlias
from prometheus_client import Counter, Gauge, Histogram
import vllm.envs as envs
from vllm.compilation.cuda_graph import CUDAGraphLogging
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorLogging,
@@ -106,6 +107,12 @@ class LoggingStatLogger(StatLoggerBase):
self.spec_decoding_logging = SpecDecodingLogging()
kv_transfer_config = self.vllm_config.kv_transfer_config
self.kv_connector_logging = KVConnectorLogging(kv_transfer_config)
self.cudagraph_logging = None
if self.vllm_config.observability_config.cudagraph_metrics:
self.cudagraph_logging = CUDAGraphLogging(
self.vllm_config.compilation_config.cudagraph_mode,
self.vllm_config.compilation_config.cudagraph_capture_sizes,
)
self.last_prompt_throughput: float = 0.0
self.last_generation_throughput: float = 0.0
self.engine_is_idle = False
@@ -161,6 +168,11 @@
self.spec_decoding_logging.observe(scheduler_stats.spec_decoding_stats)
if kv_connector_stats := scheduler_stats.kv_connector_stats:
self.kv_connector_logging.observe(kv_connector_stats)
if (
self.cudagraph_logging is not None
and scheduler_stats.cudagraph_stats is not None
):
self.cudagraph_logging.observe(scheduler_stats.cudagraph_stats)
if not self.aggregated:
self.last_scheduler_stats = scheduler_stats
if mm_cache_stats:
@@ -240,6 +252,8 @@
self.spec_decoding_logging.log(log_fn=log_fn)
self.kv_connector_logging.log(log_fn=log_fn)
if self.cudagraph_logging is not None:
self.cudagraph_logging.log(log_fn=log_fn)
def log_engine_initialized(self):
if self.vllm_config.cache_config.num_gpu_blocks:

View File

@@ -7,6 +7,7 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import vllm.envs as envs
from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.v1.spec_decode.metrics import SpecDecodingStats
if TYPE_CHECKING:
@@ -183,6 +184,8 @@ class SchedulerStats:
waiting_lora_adapters: dict[str, int] = field(default_factory=dict)
running_lora_adapters: dict[str, int] = field(default_factory=dict)
cudagraph_stats: CUDAGraphStat | None = None
@dataclass
class RequestStateStats:

View File

@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, NamedTuple
import numpy as np
import torch
from vllm.compilation.cuda_graph import CUDAGraphStat
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
@@ -169,6 +170,9 @@ class ModelRunnerOutput:
# req_id -> num_nans_in_logits
num_nans_in_logits: dict[str, int] | None = None
# information related to cudagraph execution
cudagraph_stats: CUDAGraphStat | None = None
# ModelRunnerOutput wrapper for async scheduling.
class AsyncModelRunnerOutput(ABC):

View File

@@ -27,7 +27,7 @@ from vllm.attention.backends.abstract import (
)
from vllm.attention.layer import Attention, MLAAttention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.cuda_graph import CUDAGraphStat, CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
CompilationMode,
@@ -257,6 +257,7 @@ class ExecuteModelState(NamedTuple):
sample_hidden_states: torch.Tensor
aux_hidden_states: list[torch.Tensor] | None
ec_connector_output: ECConnectorOutput | None
cudagraph_stats: CUDAGraphStat | None
class GPUModelRunner(
@@ -2755,7 +2756,11 @@
force_uniform_decode: bool | None = None,
force_has_lora: bool | None = None,
) -> tuple[
CUDAGraphMode, BatchDescriptor, UBatchSlices | None, torch.Tensor | None
CUDAGraphMode,
BatchDescriptor,
UBatchSlices | None,
torch.Tensor | None,
CUDAGraphStat | None,
]:
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
uniform_decode = (
@@ -2820,7 +2825,22 @@
# num_tokens_across_dp will no-longer be valid
assert batch_descriptor.num_tokens == num_tokens_padded
return cudagraph_mode, batch_descriptor, ubatch_slices, num_tokens_across_dp
cudagraph_stats = None
if self.vllm_config.observability_config.cudagraph_metrics:
cudagraph_stats = CUDAGraphStat(
num_unpadded_tokens=num_tokens,
num_padded_tokens=batch_descriptor.num_tokens,
num_paddings=batch_descriptor.num_tokens - num_tokens,
runtime_mode=str(cudagraph_mode),
)
return (
cudagraph_mode,
batch_descriptor,
ubatch_slices,
num_tokens_across_dp,
cudagraph_stats,
)
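
A small standalone sketch (hypothetical helper, not in the diff) of how the three token counts in CUDAGraphStat relate, assuming the runner pads a batch up to the smallest captured size that fits it, which is what the padded BatchDescriptor reflects:

def sketch_cudagraph_stat(num_tokens: int, capture_sizes: list[int]) -> dict[str, int]:
    # Pad up to the nearest capture size; fall back to the raw count if the
    # batch exceeds every captured size (eager execution, no padding).
    padded = next((s for s in sorted(capture_sizes) if s >= num_tokens), num_tokens)
    return {
        "num_unpadded_tokens": num_tokens,
        "num_padded_tokens": padded,
        "num_paddings": padded - num_tokens,
    }

# With capture sizes [1, 2, 4, 8, 16], a 5-token batch pads to 8:
# {"num_unpadded_tokens": 5, "num_padded_tokens": 8, "num_paddings": 3}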
@torch.inference_mode()
def execute_model(
@@ -2918,6 +2938,7 @@
batch_desc,
ubatch_slices,
num_tokens_across_dp,
cudagraph_stats,
) = self._determine_batch_execution_and_padding(
num_tokens=num_tokens_unpadded,
num_reqs=num_reqs,
@@ -3067,6 +3088,7 @@
sample_hidden_states,
aux_hidden_states,
ec_connector_output,
cudagraph_stats,
)
self.kv_connector_output = kv_connector_output
return None
@@ -3102,6 +3124,7 @@
sample_hidden_states,
aux_hidden_states,
ec_connector_output,
cudagraph_stats,
) = self.execute_model_state
# Clear ephemeral state.
self.execute_model_state = None
@@ -3217,6 +3240,7 @@
if self.supports_mm_inputs
else None,
num_nans_in_logits=num_nans_in_logits,
cudagraph_stats=cudagraph_stats,
)
if not self.use_async_scheduling:
@@ -3937,7 +3961,7 @@
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
_cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp = (
_cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp, _ = (
self._determine_batch_execution_and_padding(
num_tokens=num_tokens_unpadded,
num_reqs=num_reqs,

View File

@@ -564,7 +564,7 @@
# TODO(lucas): This is pretty gross; ideally we should only ever call
# `_determine_batch_execution_and_padding` once (will get called again
# in `execute_model`) but this requires a larger refactor of PP.
_, batch_desc, _, _ = (
_, batch_desc, _, _, _ = (
self.model_runner._determine_batch_execution_and_padding(
num_tokens=num_scheduled_tokens,
num_reqs=len(num_scheduled_tokens_np),