[Model Runner V2] Refactor CudaGraphManager (#29583)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon, 2025-11-26 21:37:59 -08:00 (committed via GitHub)
Parent: ecb1952378
Commit: 11ea5ec1ff

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from unittest.mock import patch
+from collections.abc import Callable, Iterable
+from typing import Any
 import numpy as np
 import torch
@@ -32,6 +33,7 @@ class CudaGraphManager:
         self.max_model_len = vllm_config.model_config.max_model_len
         self.max_num_reqs = self.scheduler_config.max_num_seqs
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.compilation_config = vllm_config.compilation_config
         assert self.compilation_config is not None
@@ -40,102 +42,60 @@ class CudaGraphManager:
             self.cudagraph_mode = CUDAGraphMode.NONE
         else:
             self.cudagraph_mode = self.compilation_config.cudagraph_mode
-        if self.compilation_config.cudagraph_capture_sizes is not None:
-            cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
-            # Limit the cudagraph sizes to the max decode batch size.
-            self.cudagraph_sizes = [
-                x for x in cudagraph_sizes if x <= self.max_num_reqs
-            ]
-        else:
-            self.cudagraph_sizes = []
-        self.padded_sizes = self._init_padded_sizes()
+        self.cudagraph_sizes = get_cudagraph_sizes(
+            self.compilation_config.cudagraph_capture_sizes,
+            self.max_num_reqs,
+            self.max_num_tokens,
+            self.cudagraph_mode,
+        )

         self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
         self.pool = torch.cuda.graph_pool_handle()
         self.hidden_states: torch.Tensor | None = None
-    def _init_padded_sizes(self) -> dict[int, int]:
-        if not self.cudagraph_mode.has_full_cudagraphs():
-            # Full cuda graphs are not used.
-            return {}
-        if not self.cudagraph_sizes:
-            return {}
-        padded_sizes: dict[int, int] = {}
-        for i in range(1, self.cudagraph_sizes[-1] + 1):
-            for x in self.cudagraph_sizes:
-                if i <= x:
-                    padded_sizes[i] = x
-                    break
-        return padded_sizes
-
     def needs_capture(self) -> bool:
-        return len(self.padded_sizes) > 0
+        return len(self.cudagraph_sizes) > 0
     def get_cudagraph_size(
         self,
         scheduler_output: SchedulerOutput,
         num_tokens_after_padding: int,
     ) -> int | None:
-        if not self.cudagraph_mode.has_full_cudagraphs():
-            return None
-        if self.cudagraph_mode != CUDAGraphMode.FULL:
-            # TODO(woosuk): Support uniform decode with multiple tokens (spec decoding).
-            all_decode = all(
-                x == 1 for x in scheduler_output.num_scheduled_tokens.values()
-            )
-            if not all_decode:
-                # Prefill is included.
-                return None
-        return self.padded_sizes.get(num_tokens_after_padding)
+        return get_cudagraph_size(
+            num_tokens_after_padding,
+            scheduler_output.num_scheduled_tokens.values(),
+            self.cudagraph_sizes,
+            self.cudagraph_mode,
+        )
     def capture_graph(
         self,
-        batch_size: int,
+        num_tokens: int,
         model: nn.Module,
         input_buffers: InputBuffers,
         block_tables: BlockTables,
         attn_metadata_builders: list[AttentionMetadataBuilder],
         kv_cache_config: KVCacheConfig,
     ) -> None:
-        assert batch_size not in self.graphs
-
-        # Prepare dummy inputs.
-        input_ids = input_buffers.input_ids.gpu[:batch_size]
-        positions = input_buffers.positions[:batch_size]
-
-        input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
-        input_buffers.query_start_loc.np[batch_size:] = batch_size
-        input_buffers.query_start_loc.copy_to_gpu()
-        # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
-        # for seq_lens. This leads to a mismatch between seq_lens (GPU) and
-        # seq_lens_np (CPU), which might cause issues in some attention backends.
-        input_buffers.seq_lens[:batch_size] = 1
-        input_buffers.seq_lens[batch_size:] = 0
-
-        input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
-        slot_mappings = block_tables.slot_mappings[:, :batch_size]
-        attn_metadata = build_attn_metadata(
-            attn_metadata_builders=attn_metadata_builders,
-            num_reqs=batch_size,
-            num_tokens=batch_size,
-            query_start_loc_gpu=input_buffers.query_start_loc.gpu[: batch_size + 1],
-            query_start_loc_cpu=input_buffers.query_start_loc.cpu[: batch_size + 1],
-            seq_lens=input_buffers.seq_lens,
-            seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
-            num_computed_tokens_cpu=None,  # FIXME
-            block_tables=input_block_tables,
-            slot_mappings=slot_mappings,
-            kv_cache_config=kv_cache_config,
-        )
+        num_reqs = min(num_tokens, self.max_num_reqs)
+        input_ids = input_buffers.input_ids.gpu[:num_tokens]
+        positions = input_buffers.positions[:num_tokens]
+        attn_metadata = prepare_inputs_to_capture(
+            num_reqs,
+            num_tokens,
+            input_buffers,
+            block_tables,
+            attn_metadata_builders,
+            self.max_model_len,
+            kv_cache_config,
+        )

-        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, batch_size)
+        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
         # Warm up.
         with set_forward_context(
             attn_metadata,
             self.vllm_config,
-            num_tokens=batch_size,
+            num_tokens=num_tokens,
             cudagraph_runtime_mode=CUDAGraphMode.NONE,
             num_tokens_across_dp=num_tokens_across_dp,
         ):
@@ -147,13 +107,13 @@ class CudaGraphManager:
             self.hidden_states = torch.empty_like(hidden_states)

         # Capture the graph.
+        assert num_tokens not in self.graphs
         graph = torch.cuda.CUDAGraph()
         with (
-            patch("torch.cuda.empty_cache", lambda: None),
             set_forward_context(
                 attn_metadata,
                 self.vllm_config,
-                num_tokens=batch_size,
+                num_tokens=num_tokens,
                 cudagraph_runtime_mode=CUDAGraphMode.NONE,
                 num_tokens_across_dp=num_tokens_across_dp,
             ),
@@ -163,8 +123,8 @@ class CudaGraphManager:
                 input_ids=input_ids,
                 positions=positions,
             )
-            self.hidden_states[:batch_size] = hidden_states
-        self.graphs[batch_size] = graph
+            self.hidden_states[:num_tokens] = hidden_states
+        self.graphs[num_tokens] = graph

     @torch.inference_mode()
     def capture(
@@ -175,25 +135,124 @@ class CudaGraphManager:
         attn_metadata_builders: list[AttentionMetadataBuilder],
         kv_cache_config: KVCacheConfig,
     ) -> None:
-        assert self.needs_capture()
-        # Capture larger graphs first.
-        sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
-        if is_global_first_rank():
-            sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
-
-        with graph_capture(device=self.device):
-            for batch_size in sizes_to_capture:
-                self.capture_graph(
-                    batch_size,
-                    model,
-                    input_buffers,
-                    block_tables,
-                    attn_metadata_builders,
-                    kv_cache_config,
-                )
-
-    def run(self, batch_size: int) -> torch.Tensor:
-        assert batch_size in self.graphs
-        self.graphs[batch_size].replay()
+        capture_graphs(
+            self.cudagraph_sizes,
+            self.device,
+            self.capture_graph,
+            model=model,
+            input_buffers=input_buffers,
+            block_tables=block_tables,
+            attn_metadata_builders=attn_metadata_builders,
+            kv_cache_config=kv_cache_config,
+        )
+
+    def run(self, num_tokens: int) -> torch.Tensor:
+        assert num_tokens in self.graphs
+        self.graphs[num_tokens].replay()
         assert self.hidden_states is not None
-        return self.hidden_states[:batch_size]
+        return self.hidden_states[:num_tokens]
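
For readers unfamiliar with the pattern that capture_graph and run rely on, this is the standard torch.cuda.CUDAGraph workflow: warm up eagerly, capture into a shared pool, then replay after overwriting static input buffers in place. A minimal standalone sketch with a toy model and buffer (not vLLM's InputBuffers or set_forward_context path):

import torch

# Toy stand-ins for the model and the static input buffer.
model = torch.nn.Linear(16, 16).cuda()
static_input = torch.zeros(8, 16, device="cuda")

# Warm up eagerly so lazy allocations happen outside of capture.
with torch.inference_mode():
    model(static_input)
torch.cuda.synchronize()

pool = torch.cuda.graph_pool_handle()  # shared memory pool, like self.pool
graph = torch.cuda.CUDAGraph()
with torch.inference_mode(), torch.cuda.graph(graph, pool=pool):
    static_output = model(static_input)  # recorded, not executed

# Replay: write new data into the same buffer, then replay the graph.
static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
result = static_output.clone()  # outputs land in the captured output tensor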
+
+
+def get_cudagraph_sizes(
+    capture_sizes: list[int] | None,
+    max_num_reqs: int,
+    max_num_tokens: int,
+    cudagraph_mode: CUDAGraphMode,
+) -> dict[int, int]:
+    if not cudagraph_mode.has_full_cudagraphs():
+        return {}
+    if not capture_sizes:
+        return {}
+    capture_sizes = sorted(capture_sizes)
+    # Limit the capture sizes to the max number of requests or tokens.
+    upper_bound = (
+        max_num_reqs
+        if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
+        else max_num_tokens
+    )
+    capture_sizes = [x for x in capture_sizes if x <= upper_bound]
+    if not capture_sizes:
+        return {}
+
+    cudagraph_sizes: dict[int, int] = {}
+    for i in range(1, capture_sizes[-1] + 1):
+        for x in capture_sizes:
+            if i <= x:
+                cudagraph_sizes[i] = x
+                break
+    return cudagraph_sizes
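
get_cudagraph_sizes returns a padding table: every batch size from 1 up to the largest retained capture size maps to the smallest captured size that fits it, and the retained sizes are capped at max_num_reqs for FULL_DECODE_ONLY (graphs keyed by request count) or max_num_tokens otherwise (graphs keyed by token count). A standalone sketch of the same bucketing with made-up capture sizes:

# Standalone restatement of the bucketing above; the capture sizes and the
# upper bound here are made up for illustration.
def pad_table(capture_sizes: list[int], upper_bound: int) -> dict[int, int]:
    sizes = sorted(x for x in capture_sizes if x <= upper_bound)
    if not sizes:
        return {}
    # Map each possible size to the smallest captured size that covers it.
    return {i: next(x for x in sizes if i <= x) for i in range(1, sizes[-1] + 1)}

table = pad_table([1, 2, 4, 8, 16], upper_bound=8)
assert table == {1: 1, 2: 2, 3: 4, 4: 4, 5: 8, 6: 8, 7: 8, 8: 8}  # 16 is dropped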
+
+
+def get_cudagraph_size(
+    num_tokens_after_dp_padding: int,
+    num_tokens_per_request: Iterable[int],
+    cudagraph_sizes: dict[int, int],
+    cudagraph_mode: CUDAGraphMode,
+) -> int | None:
+    size = cudagraph_sizes.get(num_tokens_after_dp_padding)
+    if size is None:
+        # No CUDA graph for this size.
+        return None
+    if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+        all_decode = all(x == 1 for x in num_tokens_per_request)
+        if not all_decode:
+            # Prefill is included.
+            return None
+    return size
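
Note the asymmetry: the size lookup itself is mode-agnostic, but FULL_DECODE_ONLY additionally requires every scheduled request to contribute exactly one token; a batch containing any prefill falls back to eager execution. Roughly, with a hypothetical padding table like the one in the previous sketch:

# Hypothetical mirror of the decode-only gate above; real callers pass
# CUDAGraphMode and SchedulerOutput rather than plain lists.
table = {1: 1, 2: 2, 3: 4, 4: 4, 5: 8, 6: 8, 7: 8, 8: 8}

def decode_only_lookup(table: dict[int, int], padded: int, per_req: list[int]) -> int | None:
    size = table.get(padded)
    if size is None:
        return None  # no graph captured for this padded size
    if any(n != 1 for n in per_req):
        return None  # prefill present: skip the graph
    return size

assert decode_only_lookup(table, 3, [1, 1, 1]) == 4       # pure decode batch
assert decode_only_lookup(table, 3, [37, 1, 1]) is None   # prefill included
assert decode_only_lookup(table, 999, [1]) is None        # larger than any graph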
+
+
+def capture_graphs(
+    cudagraph_sizes: dict[int, int],
+    device: torch.device,
+    capture_fn: Callable,
+    **capture_kwargs,
+) -> None:
+    # Capture larger graphs first.
+    sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
+    if is_global_first_rank():
+        sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
+    with graph_capture(device=device):
+        for size in sizes_to_capture:
+            capture_fn(size, **capture_kwargs)
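
capture_graphs now iterates over the deduplicated padded targets (the values of the table) rather than the raw capture list, and still captures the largest size first so later, smaller graphs can reuse memory from the shared pool. With the hypothetical table above, only four graphs are actually captured:

# Derivation of the capture order from a hypothetical padding table.
table = {1: 1, 2: 2, 3: 4, 4: 4, 5: 8, 6: 8, 7: 8, 8: 8}
sizes_to_capture = sorted(set(table.values()), reverse=True)
assert sizes_to_capture == [8, 4, 2, 1]  # duplicates collapsed, largest first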
+
+
+def prepare_inputs_to_capture(
+    num_reqs: int,
+    num_tokens: int,
+    input_buffers: InputBuffers,
+    block_tables: BlockTables,
+    attn_metadata_builders: list[AttentionMetadataBuilder],
+    max_model_len: int,
+    kv_cache_config: KVCacheConfig,
+) -> dict[str, Any]:
+    num_tokens_per_req = num_tokens // num_reqs
+    query_start_loc = input_buffers.query_start_loc
+    query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req
+    query_start_loc.np[num_reqs:] = num_tokens
+    query_start_loc.copy_to_gpu()
+
+    seq_lens_np = np.full(num_reqs, max_model_len, dtype=np.int32)
+    # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
+    # for seq_lens. This leads to a mismatch between seq_lens (GPU) and
+    # seq_lens_np (CPU), which might cause issues in some attention backends.
+    input_buffers.seq_lens[:num_reqs] = 1
+    input_buffers.seq_lens[num_reqs:] = 0
+
+    input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
+    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
+    attn_metadata = build_attn_metadata(
+        attn_metadata_builders=attn_metadata_builders,
+        num_reqs=num_reqs,
+        num_tokens=num_tokens,
+        query_start_loc_gpu=query_start_loc.gpu[: num_reqs + 1],
+        query_start_loc_cpu=query_start_loc.cpu[: num_reqs + 1],
+        seq_lens=input_buffers.seq_lens,
+        seq_lens_np=seq_lens_np,
+        num_computed_tokens_cpu=None,  # FIXME
+        block_tables=input_block_tables,
+        slot_mappings=slot_mappings,
+        kv_cache_config=kv_cache_config,
+    )
+    return attn_metadata
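
prepare_inputs_to_capture spreads the capture target evenly over the dummy requests (num_tokens // num_reqs tokens each) and pins the unused tail of query_start_loc to num_tokens. A quick numpy check of that arithmetic with hypothetical sizes:

import numpy as np

# Hypothetical capture target: 8 tokens over 4 dummy requests, with a
# query_start_loc buffer sized for up to 6 requests (7 entries).
num_tokens, num_reqs = 8, 4
qsl = np.zeros(7, dtype=np.int32)

num_tokens_per_req = num_tokens // num_reqs                         # 2
qsl[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req  # ramp
qsl[num_reqs:] = num_tokens                                         # pad the tail

assert qsl.tolist() == [0, 2, 4, 6, 8, 8, 8]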