From 90d24dee04fbeaf30f066ecd201ff7e32153442e Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 17 Sep 2025 20:48:14 +0000 Subject: [PATCH] enable piecewise cudagraphs for eagle Signed-off-by: Lucas Wilkinson --- vllm/forward_context.py | 7 ++++++ vllm/v1/spec_decode/eagle.py | 42 +++++++++++++++++++++++++----------- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 3b535423f7bca..02430ad18c0e3 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -321,6 +321,13 @@ def set_forward_context( attn_metadata, num_tokens or 0, num_tokens_across_dp) + # Convenience: if cudagraph is used, and num_tokens is given, we can just + # create a batch descriptor here if not given (there's no harm since if it + # doesn't match in the wrapper it'll fall through). + if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None: + batch_descriptor = batch_descriptor or BatchDescriptor( + num_tokens=num_tokens) + forward_context = create_forward_context(attn_metadata, vllm_config, virtual_engine, dp_metadata, cudagraph_runtime_mode, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 5154b29405b6e..df60cb130a1b3 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -10,7 +10,7 @@ import torch import torch.nn as nn from vllm.attention.layer import Attention -from vllm.config import (CompilationLevel, VllmConfig, +from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config) from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import set_forward_context @@ -78,6 +78,11 @@ class EagleProposer: self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not self.vllm_config.model_config.enforce_eager) + + self.cudagraph_runtime_mode = (CUDAGraphMode.PIECEWISE + if self.use_cuda_graph else + CUDAGraphMode.NONE) + 
self.cudagraph_batch_sizes = list( reversed( self.vllm_config.compilation_config.cudagraph_capture_sizes)) @@ -212,9 +217,12 @@ class EagleProposer: inputs_embeds = None input_ids = self.input_ids[:num_input_tokens] - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): + with set_forward_context( + per_layer_attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + cudagraph_runtime_mode=self.cudagraph_runtime_mode, + ): ret_hidden_states = self.model( input_ids=input_ids, positions=self.positions[:num_input_tokens], @@ -322,9 +330,12 @@ class EagleProposer: input_ids = self.input_ids[:input_batch_size] # Run the model. - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=input_batch_size): + with set_forward_context( + per_layer_attn_metadata, + self.vllm_config, + num_tokens=input_batch_size, + cudagraph_runtime_mode=self.cudagraph_runtime_mode, + ): ret_hidden_states = self.model( input_ids=input_ids, positions=self.positions[:input_batch_size], @@ -478,9 +489,12 @@ class EagleProposer: else: num_input_tokens = num_tokens # Run the model. - with set_forward_context(per_layer_attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens): + with set_forward_context( + per_layer_attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + cudagraph_runtime_mode=self.cudagraph_runtime_mode, + ): last_hidden_states, hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], positions=self.positions[:num_input_tokens], @@ -665,8 +679,12 @@ class EagleProposer: self, num_tokens: int, ) -> None: - with set_forward_context(None, self.vllm_config, - num_tokens=num_tokens): + with set_forward_context( + None, + self.vllm_config, + num_tokens=num_tokens, + cudagraph_runtime_mode=self.cudagraph_runtime_mode, + ): if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens]