From da3222f371b48c8e2548ec22767523394580a1c5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 27 Nov 2025 00:09:41 -0800 Subject: [PATCH] [Model Runner V2] Implement multi-step Eagle with CUDA graph (#29559) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/cudagraph_utils.py | 9 +- vllm/v1/worker/gpu/model_runner.py | 53 +-- vllm/v1/worker/gpu/spec_decode/eagle.py | 422 ++++++++++++++++-- .../worker/gpu/spec_decode/eagle_cudagraph.py | 112 +++++ 4 files changed, 526 insertions(+), 70 deletions(-) create mode 100644 vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index 8f1718e493b1e..4fd8eb50a4ea8 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -233,10 +233,11 @@ def prepare_inputs_to_capture( query_start_loc.np[num_reqs:] = num_tokens query_start_loc.copy_to_gpu() seq_lens_np = np.full(num_reqs, max_model_len, dtype=np.int32) - # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len) - # for seq_lens. This leads to a mismatch between seq_lens (GPU) and - # seq_lens_np (CPU), which might cause issues in some attention backends. - input_buffers.seq_lens[:num_reqs] = 1 + # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens + # rather than max_model_len. This introduces a discrepancy between + # seq_lens (on GPU) and seq_lens_np (on CPU), which may cause issues for + # certain attention backends. + input_buffers.seq_lens[:num_reqs] = num_tokens input_buffers.seq_lens[num_reqs:] = 0 input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables] diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index ed41e5a1a6c5e..0c9fdd0077f4a 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -140,10 +140,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) # CUDA graphs. - self.cudagraph_manager = CudaGraphManager( - vllm_config=self.vllm_config, - device=self.device, - ) + self.cudagraph_manager = CudaGraphManager(self.vllm_config, self.device) def get_supported_tasks(self) -> tuple[str]: return ("generate",) @@ -203,6 +200,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.vllm_config, self.device, ) + if self.do_spec_decode: + # HACK(woosuk) + self.speculator.set_attn( + self.kv_cache_config, + self.attn_metadata_builders, + self.block_tables, + ) + # TODO(woosuk): Support other backends. 
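+ # The speculator builds its own attention metadata for the multi-step
+ # draft loop using the kv_cache_config, attn_metadata_builders, and
+ # block_tables handed over via set_attn() above; that path is currently
+ # only wired up for the FLASH_ATTN backend, hence the check below.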
if not all(b.get_name() == "FLASH_ATTN" for b in self.attn_backends.values()): raise NotImplementedError("Only FLASH_ATTN backend is supported currently.") @@ -297,35 +302,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): logits = self.model.compute_logits(hidden_states) self.sampler(logits, sampling_metadata) - @torch.inference_mode() - def _dummy_speculator_run( - self, - hidden_states: torch.Tensor, - aux_hidden_states: list[torch.Tensor] | None, - ) -> None: - num_tokens = hidden_states.shape[0] - num_reqs = min(num_tokens, self.max_num_reqs) - input_batch = InputBatch.make_dummy( - num_reqs=num_reqs, - num_tokens=num_tokens, - input_buffers=self.input_buffers, - device=self.device, - ) - sampling_metadata = SamplingMetadata.make_dummy( - num_reqs=num_reqs, - device=self.device, - ) - num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device) - num_rejected = torch.zeros(num_reqs, dtype=torch.int32, device=self.device) - self.propose_draft( - input_batch=input_batch, - sampling_metadata=sampling_metadata, - last_hidden_states=hidden_states, - aux_hidden_states=aux_hidden_states, - num_sampled=num_sampled, - num_rejected=num_rejected, - ) - @torch.inference_mode() def profile_run(self) -> None: hidden_states, sample_hidden_states = self._dummy_run( @@ -334,7 +310,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ) self._dummy_sampler_run(sample_hidden_states) if self.do_spec_decode: - self._dummy_speculator_run(hidden_states, None) + num_tokens_across_dp = make_num_tokens_across_dp( + self.dp_size, self.max_num_tokens + ) + self.speculator.run_model( + self.max_num_tokens, + attn_metadata=None, + num_tokens_across_dp=num_tokens_across_dp, + ) torch.cuda.synchronize() del hidden_states, sample_hidden_states gc.collect() @@ -368,6 +351,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): attn_metadata_builders=self.attn_metadata_builders, kv_cache_config=self.kv_cache_config, ) + if self.do_spec_decode: + self.speculator.capture_model() end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py index 3c8621cc69c97..daf2775e8b92d 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle.py @@ -1,17 +1,29 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +import numpy as np import torch import torch.nn as nn from vllm.config import VllmConfig from vllm.config.compilation import CUDAGraphMode from vllm.forward_context import set_forward_context +from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.triton_utils import tl, triton -from vllm.v1.worker.gpu.input_batch import InputBatch +from vllm.utils.platform_utils import is_pin_memory_available +from vllm.v1.attention.backends.utils import AttentionMetadataBuilder +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.worker.gpu.attn_utils import build_attn_metadata +from vllm.v1.worker.gpu.block_table import BlockTables +from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.sampler import gumbel_sample +from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager from vllm.v1.worker.gpu.states import SamplingMetadata +logger = init_logger(__name__) + class EagleSpeculator: def __init__(self, 
vllm_config: VllmConfig, device: torch.device): @@ -27,13 +39,48 @@ class EagleSpeculator: self.scheduler_config = vllm_config.scheduler_config self.max_num_reqs = self.scheduler_config.max_num_seqs self.max_num_tokens = self.scheduler_config.max_num_batched_tokens + self.max_model_len = vllm_config.model_config.max_model_len + # We need to get the hidden size from the draft model config because + # the draft model's hidden size can be different from the target model's + # hidden size (e.g., Llama 3.3 70B). + self.hidden_size = self.draft_model_config.get_hidden_size() + self.vocab_size = self.draft_model_config.get_vocab_size() + self.pin_memory = is_pin_memory_available() + self.dtype = vllm_config.model_config.dtype - self.input_ids = torch.zeros( - self.max_num_tokens, dtype=torch.int32, device=device + self.input_buffers = InputBuffers( + max_num_reqs=self.max_num_reqs, + max_num_tokens=self.max_num_tokens, + hidden_size=self.hidden_size, + vocab_size=self.vocab_size, + dtype=self.dtype, + device=device, + pin_memory=self.pin_memory, ) - self.positions = torch.zeros( - self.max_num_tokens, dtype=torch.int64, device=device + self.hidden_states = torch.zeros( + self.max_num_tokens, + self.hidden_size, + dtype=self.dtype, + device=device, ) + self.temperature = torch.zeros( + self.max_num_reqs, + dtype=torch.float32, + device=device, + ) + self.seeds = torch.zeros( + self.max_num_reqs, + dtype=torch.int64, + device=device, + ) + self.draft_tokens = torch.zeros( + self.max_num_reqs, + self.num_speculative_steps, + dtype=torch.int64, + device=device, + ) + + self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device) def load_model(self, target_model: nn.Module) -> None: from vllm.compilation.backends import set_model_tag @@ -49,6 +96,91 @@ class EagleSpeculator: del self.model.lm_head self.model.lm_head = target_model.lm_head + def set_attn( + self, + kv_cache_config: KVCacheConfig, + attn_metadata_builders: list[AttentionMetadataBuilder], + block_tables: BlockTables, + ) -> None: + self.kv_cache_config = kv_cache_config + self.attn_metadata_builders = attn_metadata_builders + self.block_tables = block_tables + + @torch.inference_mode() + def run_model( + self, + num_tokens: int, + attn_metadata: dict[str, Any], + num_tokens_across_dp: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + num_tokens_across_dp=num_tokens_across_dp, + ): + ret_hidden_states = self.model( + input_ids=self.input_buffers.input_ids.gpu[:num_tokens], + positions=self.input_buffers.positions[:num_tokens], + hidden_states=self.hidden_states[:num_tokens], + ) + if self.method == "mtp": + last_hidden_states = ret_hidden_states + hidden_states = ret_hidden_states + else: + last_hidden_states, hidden_states = ret_hidden_states + return last_hidden_states, hidden_states + + def generate_draft( + self, + num_reqs: int, + attn_metadata: dict[str, Any], + num_tokens_across_dp: torch.Tensor | None, + ) -> None: + pos = self.input_buffers.positions[:num_reqs] + query_start_loc = self.input_buffers.query_start_loc.gpu[: num_reqs + 1] + for step in range(1, self.num_speculative_steps): + # Run the eagle model. 
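+ # (Step 0's draft token was already sampled in propose(); each remaining
+ # step feeds exactly one token per request back through the drafter and
+ # fills the next column of self.draft_tokens.)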
+ last_hidden_states, hidden_states = self.run_model( + num_reqs, attn_metadata, num_tokens_across_dp + ) + logits = self.model.compute_logits(last_hidden_states) + + # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise + # used for draft and target sampling. + draft_tokens = gumbel_sample( + logits, + self.temperature[:num_reqs], + self.seeds[:num_reqs], + pos + 1, + apply_temperature=True, + ) + self.draft_tokens[:num_reqs, step] = draft_tokens + + if step < self.num_speculative_steps - 1: + # Update the inputs for the next step. + update_eagle_inputs( + draft_tokens, + hidden_states, + self.input_buffers, + self.hidden_states, + self.max_model_len, + ) + self.block_tables.compute_slot_mappings(query_start_loc, pos) + + def capture_model(self) -> None: + if self.num_speculative_steps == 1: + return + logger.info("Capturing model for Eagle speculator...") + self.cudagraph_manager.capture( + self.generate_draft, + self.input_buffers, + self.block_tables, + self.attn_metadata_builders, + self.kv_cache_config, + ) + @torch.inference_mode() def propose( self, @@ -80,64 +212,110 @@ class EagleSpeculator: ) else: hidden_states = last_hidden_states + num_tokens = input_batch.num_tokens_after_padding + self.hidden_states[:num_tokens] = hidden_states # Get the input ids and last token indices for the speculator. last_token_indices = prepare_eagle_inputs( - self.input_ids, + self.input_buffers, input_batch, num_sampled, num_rejected, last_sampled, next_prefill_tokens, ) - input_ids = self.input_ids[: input_batch.num_tokens_after_padding] # Prefill: Run the eagle speculator with eager mode. - with set_forward_context( + # TODO(woosuk): Support CUDA graph for prefill. + last_hidden_states, hidden_states = self.run_model( + num_tokens, input_batch.attn_metadata, - self.vllm_config, - num_tokens=input_batch.num_tokens_after_padding, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - ): - ret_hidden_states = self.model( - input_ids=input_ids, - positions=input_batch.positions, - hidden_states=hidden_states, - ) - if self.method == "mtp": - last_hidden_states = ret_hidden_states - hidden_states = ret_hidden_states - else: - last_hidden_states, hidden_states = ret_hidden_states + num_tokens_across_dp=None, # FIXME + ) sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states) num_reqs = input_batch.num_reqs cu_num_logits = input_batch.cu_num_logits[:num_reqs] - temperature = sampling_metadata.temperature[cu_num_logits] - seed = sampling_metadata.seeds[cu_num_logits] - # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise - # used for draft and target sampling. - pos = input_batch.positions[last_token_indices] + 1 # NOTE(woosuk): For draft sampling, we only consider the temperature # and ignore the other sampling parameters such as top_k and top_p, # for simplicity and performance. # While this may slightly degrade the acceptance rate, it does not # affect the output distribution after rejection sampling. + temperature = self.temperature[:num_reqs] + seeds = self.seeds[:num_reqs] + pos = self.input_buffers.positions[:num_reqs] + # Gather the values and copy them to the pre-allocated buffers. 
+ torch.gather(sampling_metadata.temperature, 0, cu_num_logits, out=temperature) + torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds) + torch.gather(input_batch.positions, 0, last_token_indices, out=pos) + # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise + # used for draft and target sampling. draft_tokens = gumbel_sample( - logits, temperature, seed, pos, apply_temperature=True + logits, temperature, seeds, pos + 1, apply_temperature=True ) if self.num_speculative_steps == 1: # Early exit. return draft_tokens.view(-1, 1) - raise NotImplementedError("num_speculative_steps > 1 is not supported yet.") + + # Save the draft tokens for the first step. + self.draft_tokens[:num_reqs, 0] = draft_tokens + # Prepare the inputs for the decode steps. + prepare_eagle_decode( + draft_tokens, + hidden_states, + last_token_indices, + input_batch.seq_lens, + num_rejected, + self.input_buffers, + self.hidden_states, + self.max_model_len, + self.max_num_reqs, + ) + query_start_loc = self.input_buffers.query_start_loc + query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1] + slot_mappings = self.block_tables.compute_slot_mappings( + query_start_loc_gpu, pos + ) + + cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs) + if cudagraph_size is not None: + # Run CUDA graph. + self.cudagraph_manager.run(cudagraph_size) + return self.draft_tokens[:num_reqs] + + # Run eager mode. + query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1) + query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1] + # HACK(woosuk) + seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32) + block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables] + + # FIXME(woosuk): This is UNSAFE!! + attn_metadata = build_attn_metadata( + attn_metadata_builders=self.attn_metadata_builders, + num_reqs=num_reqs, + num_tokens=num_reqs, + query_start_loc_gpu=query_start_loc_gpu, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=self.input_buffers.seq_lens[:num_reqs], + seq_lens_np=seq_lens_np, + num_computed_tokens_cpu=None, # FIXME + block_tables=block_tables, + slot_mappings=slot_mappings, + kv_cache_config=self.kv_cache_config, + ) + self.generate_draft(num_reqs, attn_metadata, num_tokens_across_dp=None) # FIXME + return self.draft_tokens[:num_reqs] @triton.jit def _prepare_eagle_inputs_kernel( last_token_indices_ptr, eagle_input_ids_ptr, + eagle_positions_ptr, target_input_ids_ptr, + target_positions_ptr, idx_mapping_ptr, last_sampled_ptr, next_prefill_tokens_ptr, @@ -175,9 +353,16 @@ def _prepare_eagle_inputs_kernel( tl.store(last_token_indices_ptr + batch_idx, last_token_index) tl.store(eagle_input_ids_ptr + last_token_index, next_token) + # Copy positions. 
+ for i in range(0, query_len, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < query_len + target_pos = tl.load(target_positions_ptr + query_start + block, mask=mask) + tl.store(eagle_positions_ptr + query_start + block, target_pos, mask=mask) + def prepare_eagle_inputs( - eagle_input_ids: torch.Tensor, + input_buffers: InputBuffers, input_batch: InputBatch, # [num_reqs] num_sampled: torch.Tensor, @@ -192,12 +377,14 @@ def prepare_eagle_inputs( last_token_indices = torch.empty( num_reqs, dtype=torch.int64, - device=eagle_input_ids.device, + device=num_sampled.device, ) _prepare_eagle_inputs_kernel[(num_reqs,)]( last_token_indices, - eagle_input_ids, + input_buffers.input_ids.gpu, + input_buffers.positions, input_batch.input_ids, + input_batch.positions, input_batch.idx_mapping, last_sampled, next_prefill_tokens, @@ -207,3 +394,174 @@ def prepare_eagle_inputs( BLOCK_SIZE=1024, ) return last_token_indices + + +@triton.jit +def _prepare_eagle_docode_kernel( + draft_tokens_ptr, + output_hidden_states_ptr, + output_hidden_states_stride, + last_token_indices_ptr, + target_seq_lens_ptr, + num_rejected_ptr, + input_ids_ptr, + positions_ptr, + input_hidden_states_ptr, + input_hidden_states_stride, + query_start_loc_ptr, + seq_lens_ptr, + hidden_size, + max_model_len, + max_num_reqs, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + num_reqs = tl.num_programs(0) - 1 + if req_idx == num_reqs: + # Compute query_start_loc. Pad it with the last query_start_loc + # for CUDA graphs. + for i in range(0, max_num_reqs + 1, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + q = tl.where(block < num_reqs, block, num_reqs) + mask = block < max_num_reqs + 1 + tl.store(query_start_loc_ptr + block, q, mask=mask) + # Pad seq_lens for CUDA graphs. + for i in range(req_idx, max_num_reqs, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < max_num_reqs + tl.store(seq_lens_ptr + block, 0, mask=mask) + return + + # draft token -> input id. + draft_token = tl.load(draft_tokens_ptr + req_idx) + tl.store(input_ids_ptr + req_idx, draft_token) + + # output hidden states -> input hidden states. + src_idx = tl.load(last_token_indices_ptr + req_idx) + for i in range(0, hidden_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < hidden_size + output_hidden_states = tl.load( + output_hidden_states_ptr + src_idx * output_hidden_states_stride + block, + mask=mask, + ) + tl.store( + input_hidden_states_ptr + req_idx * input_hidden_states_stride + block, + output_hidden_states, + mask=mask, + ) + + # Compute position and seq_lens. + # NOTE(woosuk): To prevent out-of-range access, we clamp these values + # if they reach the max model length. 
+ position = tl.load(positions_ptr + req_idx) + position = tl.minimum(position + 1, max_model_len - 1) + tl.store(positions_ptr + req_idx, position) + + target_seq_len = tl.load(target_seq_lens_ptr + req_idx) + num_rejected = tl.load(num_rejected_ptr + req_idx) + seq_len = target_seq_len - num_rejected + seq_len = tl.minimum(seq_len + 1, max_model_len) + tl.store(seq_lens_ptr + req_idx, seq_len) + + +def prepare_eagle_decode( + draft_tokens: torch.Tensor, + output_hidden_states: torch.Tensor, + last_token_indices: torch.Tensor, + target_seq_lens: torch.Tensor, + num_rejected: torch.Tensor, + input_buffers: InputBuffers, + input_hidden_states: torch.Tensor, + max_model_len: int, + max_num_reqs: int, +): + num_reqs = draft_tokens.shape[0] + hidden_size = output_hidden_states.shape[-1] + _prepare_eagle_docode_kernel[(num_reqs + 1,)]( + draft_tokens, + output_hidden_states, + output_hidden_states.stride(0), + last_token_indices, + target_seq_lens, + num_rejected, + input_buffers.input_ids.gpu, + input_buffers.positions, + input_hidden_states, + input_hidden_states.stride(0), + input_buffers.query_start_loc.gpu, + input_buffers.seq_lens, + hidden_size, + max_model_len, + max_num_reqs, + BLOCK_SIZE=1024, + ) + + +@triton.jit +def _update_eagle_inputs_kernel( + input_ids_ptr, + positions_ptr, + input_hidden_states_ptr, + input_hidden_states_stride, + seq_lens_ptr, + max_model_len, + draft_tokens_ptr, + output_hidden_states_ptr, + output_hidden_states_stride, + hidden_size, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + + # Draft token -> Input ID. + draft_token = tl.load(draft_tokens_ptr + req_idx) + tl.store(input_ids_ptr + req_idx, draft_token) + + # Output hidden states -> Input hidden states. + for i in range(0, hidden_size, BLOCK_SIZE): + block = i + tl.arange(0, BLOCK_SIZE) + mask = block < hidden_size + output_hidden_states = tl.load( + output_hidden_states_ptr + req_idx * output_hidden_states_stride + block, + mask=mask, + ) + tl.store( + input_hidden_states_ptr + req_idx * input_hidden_states_stride + block, + output_hidden_states, + mask=mask, + ) + + # Increment position and seq_lens. + # NOTE(woosuk): To prevent out-of-range access, we clamp these values + # if they reach the max model length. 
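+ # Between decode steps, advancing positions and seq_lens by one is all
+ # that is needed, since each step appends exactly one draft token per
+ # request; the clamp protects requests that already hit max_model_len.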
+ position = tl.load(positions_ptr + req_idx) + position = tl.minimum(position + 1, max_model_len - 1) + tl.store(positions_ptr + req_idx, position) + + seq_len = tl.load(seq_lens_ptr + req_idx) + seq_len = tl.minimum(seq_len + 1, max_model_len) + tl.store(seq_lens_ptr + req_idx, seq_len) + + +def update_eagle_inputs( + draft_tokens: torch.Tensor, + output_hidden_states: torch.Tensor, + input_buffers: InputBuffers, + hidden_states: torch.Tensor, + max_model_len: int, +): + num_reqs, hidden_size = output_hidden_states.shape + _update_eagle_inputs_kernel[(num_reqs,)]( + input_buffers.input_ids.gpu, + input_buffers.positions, + hidden_states, + hidden_states.stride(0), + input_buffers.seq_lens, + max_model_len, + draft_tokens, + output_hidden_states, + output_hidden_states.stride(0), + hidden_size, + BLOCK_SIZE=1024, + ) diff --git a/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py b/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py new file mode 100644 index 0000000000000..a6f50d68cc684 --- /dev/null +++ b/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py @@ -0,0 +1,112 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable + +import torch + +from vllm.config import VllmConfig +from vllm.config.compilation import CUDAGraphMode +from vllm.v1.attention.backends.utils import AttentionMetadataBuilder +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.worker.gpu.block_table import BlockTables +from vllm.v1.worker.gpu.cudagraph_utils import ( + capture_graphs, + get_cudagraph_sizes, + prepare_inputs_to_capture, +) +from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp +from vllm.v1.worker.gpu.input_batch import InputBuffers + + +class EagleCudaGraphManager: + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + self.vllm_config = vllm_config + self.scheduler_config = vllm_config.scheduler_config + self.device = device + + self.max_model_len = vllm_config.model_config.max_model_len + self.max_num_reqs = self.scheduler_config.max_num_seqs + self.max_num_tokens = self.scheduler_config.max_num_batched_tokens + self.dp_size = vllm_config.parallel_config.data_parallel_size + self.compilation_config = vllm_config.compilation_config + assert self.compilation_config is not None + + if self.compilation_config.cudagraph_mode is None: + self.cudagraph_mode = CUDAGraphMode.NONE + else: + self.cudagraph_mode = self.compilation_config.cudagraph_mode + if self.cudagraph_mode == CUDAGraphMode.FULL: + # NOTE(woosuk): For Eagle, we only use CUDA graphs for decode. 
+ self.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY + + self.cudagraph_sizes = get_cudagraph_sizes( + self.compilation_config.cudagraph_capture_sizes, + self.max_num_reqs, + self.max_num_tokens, + self.cudagraph_mode, + ) + + self.graphs: dict[int, torch.cuda.CUDAGraph] = {} + self.pool = torch.cuda.graph_pool_handle() + + def get_cudagraph_size(self, num_tokens: int) -> int | None: + return self.cudagraph_sizes.get(num_tokens) + + def capture_graph( + self, + num_tokens: int, + generate_fn: Callable, + input_buffers: InputBuffers, + block_tables: BlockTables, + attn_metadata_builders: list[AttentionMetadataBuilder], + kv_cache_config: KVCacheConfig, + ) -> None: + num_reqs = min(num_tokens, self.max_num_reqs) + attn_metadata = prepare_inputs_to_capture( + num_reqs, + num_tokens, + input_buffers, + block_tables, + attn_metadata_builders, + self.max_model_len, + kv_cache_config, + ) + num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) + + # Warm up. + generate_fn(num_tokens, attn_metadata, num_tokens_across_dp) + + # Capture the graph. + assert num_tokens not in self.graphs + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, self.pool): + generate_fn(num_tokens, attn_metadata, num_tokens_across_dp) + self.graphs[num_tokens] = graph + + @torch.inference_mode() + def capture( + self, + generate_fn: Callable, + input_buffers: InputBuffers, + block_tables: BlockTables, + attn_metadata_builders: list[AttentionMetadataBuilder], + kv_cache_config: KVCacheConfig, + ) -> None: + capture_graphs( + self.cudagraph_sizes, + self.device, + self.capture_graph, + generate_fn=generate_fn, + input_buffers=input_buffers, + block_tables=block_tables, + attn_metadata_builders=attn_metadata_builders, + kv_cache_config=kv_cache_config, + ) + + def run(self, num_tokens: int) -> None: + assert num_tokens in self.graphs + self.graphs[num_tokens].replay()
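The manager above only replays previously captured graphs, so every tensor the captured region touches must already live in a pre-allocated buffer. The sketch below (plain PyTorch, not part of the patch; the decode_step helper and tensor names are illustrative stand-ins, not vLLM APIs) shows that capture-and-replay pattern: write new inputs into the fixed buffers in place, then call replay().

import torch


def decode_step(inp: torch.Tensor, out: torch.Tensor) -> None:
    # Stand-in for one fixed-shape draft step: any static-shape GPU work.
    out.copy_(inp * 2 + 1)


device = torch.device("cuda")
inp = torch.zeros(8, device=device)  # pre-allocated input buffer
out = torch.zeros(8, device=device)  # pre-allocated output buffer

# Warm up outside the graph (lazy init, allocator warmup, autotuning).
for _ in range(3):
    decode_step(inp, out)
torch.cuda.synchronize()

# Capture: the graph records the kernels and the buffer addresses they use.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    decode_step(inp, out)

# Replay: overwrite the input buffer in place, then replay the graph.
inp.fill_(3.0)
graph.replay()
torch.cuda.synchronize()
print(out)  # tensor([7., 7., ..., 7.]) since out = inp * 2 + 1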