[Model Runner V2] Refactor CudaGraphManager (#29583)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent: ecb1952378
commit: 11ea5ec1ff
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from unittest.mock import patch
+from collections.abc import Callable, Iterable
+from typing import Any
 
 import numpy as np
 import torch
@@ -32,6 +33,7 @@ class CudaGraphManager:
 
         self.max_model_len = vllm_config.model_config.max_model_len
         self.max_num_reqs = self.scheduler_config.max_num_seqs
+        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         self.dp_size = vllm_config.parallel_config.data_parallel_size
         self.compilation_config = vllm_config.compilation_config
         assert self.compilation_config is not None
@@ -40,102 +42,60 @@ class CudaGraphManager:
             self.cudagraph_mode = CUDAGraphMode.NONE
         else:
             self.cudagraph_mode = self.compilation_config.cudagraph_mode
-        if self.compilation_config.cudagraph_capture_sizes is not None:
-            cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
-            # Limit the cudagraph sizes to the max decode batch size.
-            self.cudagraph_sizes = [
-                x for x in cudagraph_sizes if x <= self.max_num_reqs
-            ]
-        else:
-            self.cudagraph_sizes = []
-        self.padded_sizes = self._init_padded_sizes()
+        self.cudagraph_sizes = get_cudagraph_sizes(
+            self.compilation_config.cudagraph_capture_sizes,
+            self.max_num_reqs,
+            self.max_num_tokens,
+            self.cudagraph_mode,
+        )
 
         self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
         self.pool = torch.cuda.graph_pool_handle()
         self.hidden_states: torch.Tensor | None = None
 
-    def _init_padded_sizes(self) -> dict[int, int]:
-        if not self.cudagraph_mode.has_full_cudagraphs():
-            # Full cuda graphs are not used.
-            return {}
-        if not self.cudagraph_sizes:
-            return {}
-
-        padded_sizes: dict[int, int] = {}
-        for i in range(1, self.cudagraph_sizes[-1] + 1):
-            for x in self.cudagraph_sizes:
-                if i <= x:
-                    padded_sizes[i] = x
-                    break
-        return padded_sizes
-
     def needs_capture(self) -> bool:
-        return len(self.padded_sizes) > 0
+        return len(self.cudagraph_sizes) > 0
 
     def get_cudagraph_size(
         self,
         scheduler_output: SchedulerOutput,
         num_tokens_after_padding: int,
     ) -> int | None:
-        if not self.cudagraph_mode.has_full_cudagraphs():
-            return None
-        if self.cudagraph_mode != CUDAGraphMode.FULL:
-            # TODO(woosuk): Support uniform decode with multiple tokens (spec decoding).
-            all_decode = all(
-                x == 1 for x in scheduler_output.num_scheduled_tokens.values()
-            )
-            if not all_decode:
-                # Prefill is included.
-                return None
-        return self.padded_sizes.get(num_tokens_after_padding)
+        return get_cudagraph_size(
+            num_tokens_after_padding,
+            scheduler_output.num_scheduled_tokens.values(),
+            self.cudagraph_sizes,
+            self.cudagraph_mode,
+        )
 
     def capture_graph(
         self,
-        batch_size: int,
+        num_tokens: int,
         model: nn.Module,
         input_buffers: InputBuffers,
         block_tables: BlockTables,
         attn_metadata_builders: list[AttentionMetadataBuilder],
         kv_cache_config: KVCacheConfig,
     ) -> None:
-        assert batch_size not in self.graphs
-
-        # Prepare dummy inputs.
-        input_ids = input_buffers.input_ids.gpu[:batch_size]
-        positions = input_buffers.positions[:batch_size]
-
-        input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
-        input_buffers.query_start_loc.np[batch_size:] = batch_size
-        input_buffers.query_start_loc.copy_to_gpu()
-        # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
-        # for seq_lens. This leads to a mismatch between seq_lens (GPU) and
-        # seq_lens_np (CPU), which might cause issues in some attention backends.
-        input_buffers.seq_lens[:batch_size] = 1
-        input_buffers.seq_lens[batch_size:] = 0
-
-        input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
-        slot_mappings = block_tables.slot_mappings[:, :batch_size]
-
-        attn_metadata = build_attn_metadata(
-            attn_metadata_builders=attn_metadata_builders,
-            num_reqs=batch_size,
-            num_tokens=batch_size,
-            query_start_loc_gpu=input_buffers.query_start_loc.gpu[: batch_size + 1],
-            query_start_loc_cpu=input_buffers.query_start_loc.cpu[: batch_size + 1],
-            seq_lens=input_buffers.seq_lens,
-            seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
-            num_computed_tokens_cpu=None,  # FIXME
-            block_tables=input_block_tables,
-            slot_mappings=slot_mappings,
-            kv_cache_config=kv_cache_config,
-        )
-        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, batch_size)
+        num_reqs = min(num_tokens, self.max_num_reqs)
+        input_ids = input_buffers.input_ids.gpu[:num_tokens]
+        positions = input_buffers.positions[:num_tokens]
+        attn_metadata = prepare_inputs_to_capture(
+            num_reqs,
+            num_tokens,
+            input_buffers,
+            block_tables,
+            attn_metadata_builders,
+            self.max_model_len,
+            kv_cache_config,
+        )
+        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
 
         # Warm up.
         with set_forward_context(
             attn_metadata,
             self.vllm_config,
-            num_tokens=batch_size,
+            num_tokens=num_tokens,
             cudagraph_runtime_mode=CUDAGraphMode.NONE,
             num_tokens_across_dp=num_tokens_across_dp,
         ):
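Note: for readers skimming the diff, the padding map that the removed _init_padded_sizes built (and that the new module-level get_cudagraph_sizes in the last hunk rebuilds) maps every token count up to the largest capture size onto the smallest captured size that can hold it. A minimal standalone sketch of that logic, using made-up capture sizes rather than vLLM's config values:

# Illustrative sketch only, not vLLM code: token count -> smallest capture size >= it.
def build_padding_map(capture_sizes: list[int]) -> dict[int, int]:
    capture_sizes = sorted(capture_sizes)
    padded: dict[int, int] = {}
    for n in range(1, capture_sizes[-1] + 1):
        for size in capture_sizes:
            if n <= size:
                padded[n] = size
                break
    return padded

# Hypothetical sizes: a batch of 3 tokens is padded up to the size-4 graph.
assert build_padding_map([1, 2, 4, 8])[3] == 4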
@@ -147,13 +107,13 @@ class CudaGraphManager:
             self.hidden_states = torch.empty_like(hidden_states)
 
         # Capture the graph.
+        assert num_tokens not in self.graphs
         graph = torch.cuda.CUDAGraph()
         with (
-            patch("torch.cuda.empty_cache", lambda: None),
            set_forward_context(
                 attn_metadata,
                 self.vllm_config,
-                num_tokens=batch_size,
+                num_tokens=num_tokens,
                 cudagraph_runtime_mode=CUDAGraphMode.NONE,
                 num_tokens_across_dp=num_tokens_across_dp,
             ),
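Note: the hunk above follows the usual PyTorch CUDA graph recipe: run the model once outside the graph to warm up, capture the forward pass into a torch.cuda.CUDAGraph, and later replay it against static input/output buffers. A minimal sketch of that pattern with a toy model (illustrative only, not vLLM code):

import torch

model = torch.nn.Linear(16, 16).cuda().eval()
static_input = torch.zeros(8, 16, device="cuda")

# Warm up on a side stream before capture, as PyTorch recommends.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream), torch.no_grad():
    model(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture the forward pass; outputs land in a static buffer.
graph = torch.cuda.CUDAGraph()
with torch.no_grad(), torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: refresh the static input in place, then rerun the recorded kernels.
static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()  # static_output now holds results for the new input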
@@ -163,8 +123,8 @@ class CudaGraphManager:
                 input_ids=input_ids,
                 positions=positions,
             )
-        self.hidden_states[:batch_size] = hidden_states
-        self.graphs[batch_size] = graph
+        self.hidden_states[:num_tokens] = hidden_states
+        self.graphs[num_tokens] = graph
 
     @torch.inference_mode()
     def capture(
@@ -175,25 +135,124 @@ class CudaGraphManager:
         attn_metadata_builders: list[AttentionMetadataBuilder],
         kv_cache_config: KVCacheConfig,
     ) -> None:
-        assert self.needs_capture()
-        # Capture larger graphs first.
-        sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
-        if is_global_first_rank():
-            sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
-
-        with graph_capture(device=self.device):
-            for batch_size in sizes_to_capture:
-                self.capture_graph(
-                    batch_size,
-                    model,
-                    input_buffers,
-                    block_tables,
-                    attn_metadata_builders,
-                    kv_cache_config,
-                )
-
-    def run(self, batch_size: int) -> torch.Tensor:
-        assert batch_size in self.graphs
-        self.graphs[batch_size].replay()
+        capture_graphs(
+            self.cudagraph_sizes,
+            self.device,
+            self.capture_graph,
+            model=model,
+            input_buffers=input_buffers,
+            block_tables=block_tables,
+            attn_metadata_builders=attn_metadata_builders,
+            kv_cache_config=kv_cache_config,
+        )
+
+    def run(self, num_tokens: int) -> torch.Tensor:
+        assert num_tokens in self.graphs
+        self.graphs[num_tokens].replay()
         assert self.hidden_states is not None
-        return self.hidden_states[:batch_size]
+        return self.hidden_states[:num_tokens]
+
+
+def get_cudagraph_sizes(
+    capture_sizes: list[int] | None,
+    max_num_reqs: int,
+    max_num_tokens: int,
+    cudagraph_mode: CUDAGraphMode,
+) -> dict[int, int]:
+    if not cudagraph_mode.has_full_cudagraphs():
+        return {}
+    if not capture_sizes:
+        return {}
+
+    capture_sizes = sorted(capture_sizes)
+    # Limit the capture sizes to the max number of requests or tokens.
+    upper_bound = (
+        max_num_reqs
+        if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
+        else max_num_tokens
+    )
+    capture_sizes = [x for x in capture_sizes if x <= upper_bound]
+    if not capture_sizes:
+        return {}
+
+    cudagraph_sizes: dict[int, int] = {}
+    for i in range(1, capture_sizes[-1] + 1):
+        for x in capture_sizes:
+            if i <= x:
+                cudagraph_sizes[i] = x
+                break
+    return cudagraph_sizes
+
+
+def get_cudagraph_size(
+    num_tokens_after_dp_padding: int,
+    num_tokens_per_request: Iterable[int],
+    cudagraph_sizes: dict[int, int],
+    cudagraph_mode: CUDAGraphMode,
+) -> int | None:
+    size = cudagraph_sizes.get(num_tokens_after_dp_padding)
+    if size is None:
+        # No CUDA graph for this size.
+        return None
+    if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
+        all_decode = all(x == 1 for x in num_tokens_per_request)
+        if not all_decode:
+            # Prefill is included.
+            return None
+    return size
+
+
+def capture_graphs(
+    cudagraph_sizes: dict[int, int],
+    device: torch.device,
+    capture_fn: Callable,
+    **capture_kwargs,
+) -> None:
+    # Capture larger graphs first.
+    sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
+    if is_global_first_rank():
+        sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
+
+    with graph_capture(device=device):
+        for size in sizes_to_capture:
+            capture_fn(size, **capture_kwargs)
+
+
+def prepare_inputs_to_capture(
+    num_reqs: int,
+    num_tokens: int,
+    input_buffers: InputBuffers,
+    block_tables: BlockTables,
+    attn_metadata_builders: list[AttentionMetadataBuilder],
+    max_model_len: int,
+    kv_cache_config: KVCacheConfig,
+) -> dict[str, Any]:
+    num_tokens_per_req = num_tokens // num_reqs
+    query_start_loc = input_buffers.query_start_loc
+    query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) * num_tokens_per_req
+    query_start_loc.np[num_reqs:] = num_tokens
+    query_start_loc.copy_to_gpu()
+    seq_lens_np = np.full(num_reqs, max_model_len, dtype=np.int32)
+    # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
+    # for seq_lens. This leads to a mismatch between seq_lens (GPU) and
+    # seq_lens_np (CPU), which might cause issues in some attention backends.
+    input_buffers.seq_lens[:num_reqs] = 1
+    input_buffers.seq_lens[num_reqs:] = 0
+
+    input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
+    slot_mappings = block_tables.slot_mappings[:, :num_tokens]
+
+    attn_metadata = build_attn_metadata(
+        attn_metadata_builders=attn_metadata_builders,
+        num_reqs=num_reqs,
+        num_tokens=num_tokens,
+        query_start_loc_gpu=query_start_loc.gpu[: num_reqs + 1],
+        query_start_loc_cpu=query_start_loc.cpu[: num_reqs + 1],
+        seq_lens=input_buffers.seq_lens,
+        seq_lens_np=seq_lens_np,
+        num_computed_tokens_cpu=None,  # FIXME
+        block_tables=input_block_tables,
+        slot_mappings=slot_mappings,
+        kv_cache_config=kv_cache_config,
+    )
+    return attn_metadata
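Note: to make the new module-level dispatch concrete, here is a small standalone sketch of the gating that get_cudagraph_size performs: the DP-padded token count must appear in the capture map, and in decode-only mode every request must be scheduling exactly one token. The enum stub and sizes below are made up for illustration; they are not the vLLM types.

from collections.abc import Iterable
from enum import Enum, auto

class Mode(Enum):  # stand-in for vLLM's CUDAGraphMode, illustration only
    FULL = auto()
    FULL_DECODE_ONLY = auto()

def pick_graph_size(
    num_tokens_after_dp_padding: int,
    num_tokens_per_request: Iterable[int],
    cudagraph_sizes: dict[int, int],
    mode: Mode,
) -> int | None:
    size = cudagraph_sizes.get(num_tokens_after_dp_padding)
    if size is None:
        return None  # no captured graph covers this padded size
    if mode is Mode.FULL_DECODE_ONLY and any(n != 1 for n in num_tokens_per_request):
        return None  # a prefill is present, so fall back to a non-full-graph path
    return size

# Hypothetical capture map for capture sizes [1, 2, 4].
sizes = {1: 1, 2: 2, 3: 4, 4: 4}
assert pick_graph_size(3, [1, 1, 1], sizes, Mode.FULL_DECODE_ONLY) == 4
assert pick_graph_size(3, [1, 2], sizes, Mode.FULL_DECODE_ONLY) is None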