[Spec-Decode] Support piecewise cudagraphs for Eagle head (#25109)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
Author: Lucas Wilkinson
Date: 2025-10-10 01:20:31 -04:00 (committed by GitHub)
parent da4455609d
commit 29255cfc3b
6 changed files with 84 additions and 16 deletions

View File

@@ -50,11 +50,14 @@ class CUDAGraphMode(enum.Enum):
    def mixed_mode(self) -> "CUDAGraphMode":
        return CUDAGraphMode(self.value[1]) if self.separate_routine() else self

    def has_mode(self, mode: "CUDAGraphMode") -> bool:
        assert not mode.separate_routine()
        if self.separate_routine():
            return mode.value in self.value
        return self == mode

    def requires_piecewise_compilation(self) -> bool:
-       return (
-           self.decode_mode() == CUDAGraphMode.PIECEWISE
-           or self.mixed_mode() == CUDAGraphMode.PIECEWISE
-       )
+       return self.has_mode(CUDAGraphMode.PIECEWISE)

    def max_cudagraph_mode(self) -> "CUDAGraphMode":
        return CUDAGraphMode(max(self.value)) if self.separate_routine() else self
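For orientation, here is a minimal, self-contained sketch of the semantics has_mode adds. The scalar/tuple member values below only mirror vLLM's composite modes (e.g. FULL_AND_PIECEWISE) and are illustrative rather than the exact definitions:

import enum

class CUDAGraphMode(enum.Enum):
    # Single modes carry a scalar value; composite modes carry a
    # (decode_mode, mixed_mode) tuple and use a separate routine per phase.
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    FULL_AND_PIECEWISE = (2, 1)

    def separate_routine(self) -> bool:
        return isinstance(self.value, tuple)

    def has_mode(self, mode: "CUDAGraphMode") -> bool:
        assert not mode.separate_routine()
        if self.separate_routine():
            return mode.value in self.value
        return self == mode

    def requires_piecewise_compilation(self) -> bool:
        return self.has_mode(CUDAGraphMode.PIECEWISE)

# FULL_AND_PIECEWISE uses piecewise graphs for the mixed prefill/decode route.
assert CUDAGraphMode.FULL_AND_PIECEWISE.requires_piecewise_compilation()
# A plain FULL mode never needs piecewise compilation.
assert not CUDAGraphMode.FULL.requires_piecewise_compilation()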

View File

@@ -283,6 +283,12 @@ def set_forward_context(
        vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
    )

+   # Convenience: if cudagraph is used and num_tokens is given, we can just
+   # create a batch descriptor here if not given (there's no harm since if it
+   # doesn't match in the wrapper it'll fall through).
+   if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
+       batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)

    forward_context = create_forward_context(
        attn_metadata,
        vllm_config,
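The effect is that callers which pass only num_tokens under a cudagraph mode get a usable BatchDescriptor for free. A rough standalone sketch of that fallback (the dataclass and helper below are illustrative stand-ins, not the vLLM types):

from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class BatchDescriptor:
    num_tokens: int

def resolve_batch_descriptor(
    batch_descriptor: Optional[BatchDescriptor],
    num_tokens: Optional[int],
    cudagraphs_enabled: bool,
) -> Optional[BatchDescriptor]:
    # Keep a caller-provided descriptor; otherwise derive one from num_tokens
    # so the cudagraph wrapper has something to match against. If it does not
    # match a captured size, the wrapper simply falls through to eager.
    if cudagraphs_enabled and num_tokens is not None:
        return batch_descriptor or BatchDescriptor(num_tokens=num_tokens)
    return batch_descriptor

assert resolve_batch_descriptor(None, 128, True) == BatchDescriptor(num_tokens=128)
assert resolve_batch_descriptor(None, 128, False) is None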

View File

@@ -7,6 +7,7 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig

+from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -162,6 +163,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
        return logits


+@support_torch_compile
class DeepSeekMTP(nn.Module, SupportsPP):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
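Decorating the whole DeepSeekMTP module makes its forward pass a compilation unit that the piecewise cudagraph machinery can capture. As a very rough sketch of the decorator pattern only (vLLM's support_torch_compile is config-aware and does considerably more; compile_forward and TinyMTP are illustrative names):

import torch
import torch.nn as nn

def compile_forward(cls):
    # Class decorator: compile the module's forward once, at construction time.
    orig_init = cls.__init__

    def __init__(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        self.forward = torch.compile(self.forward)

    cls.__init__ = __init__
    return cls

@compile_forward
class TinyMTP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

out = TinyMTP()(torch.randn(2, 8))  # forward runs through torch.compile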

View File

@@ -8,6 +8,7 @@ import torch
import torch.nn as nn
from transformers import LlamaConfig

+from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -21,6 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
+from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors

from .utils import AutoWeightsLoader, maybe_prefix
@@ -119,6 +121,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
        return hidden_states, residual


+@support_torch_compile(
+    # torch.compile is disabled for multimodal EAGLE3 models due to constraint
+    # violations with dynamic shapes during tensor concatenation operations.
+    # See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132
+    # Non-multimodal EAGLE3 models can still use torch.compile safely.
+    enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs(
+        vllm_config.model_config
+    ),
+)
class LlamaModel(nn.Module):
    def __init__(
        self,
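The enable_if predicate is what lets the same class be compiled for text-only EAGLE3 drafters while staying eager for multimodal ones. A small sketch of that conditional-compilation decision in isolation (maybe_compile is a hypothetical helper, not the vLLM API):

import torch
import torch.nn as nn

def maybe_compile(module: nn.Module, is_multimodal: bool) -> nn.Module:
    # Mirrors the predicate above: skip torch.compile when the model config
    # supports multimodal inputs, compile otherwise.
    if is_multimodal:
        return module
    return torch.compile(module)

layer = nn.Linear(16, 16)                         # stand-in for the EAGLE3 LlamaModel
compiled = maybe_compile(layer, is_multimodal=False)
out = compiled(torch.randn(2, 16))                # compiled path
eager = maybe_compile(layer, is_multimodal=True)(torch.randn(2, 16))  # eager path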

View File

@@ -9,7 +9,12 @@ import numpy as np
import torch
import torch.nn as nn

-from vllm.config import CompilationLevel, VllmConfig, get_layers_from_vllm_config
+from vllm.config import (
+    CompilationLevel,
+    CUDAGraphMode,
+    VllmConfig,
+    get_layers_from_vllm_config,
+)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
@@ -80,12 +85,25 @@ class EagleProposer:
        self.attn_layer_names: list[str] = []
        self.indexer_layer_names: list[str] = []

-       self.use_cuda_graph = (
-           not current_platform.is_xpu()
-           and self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
-           and not self.vllm_config.model_config.enforce_eager
-           and not self.speculative_config.enforce_eager
-       )
+       self.use_cuda_graph = False
+       compilation_config = self.vllm_config.compilation_config
+       if compilation_config.level == CompilationLevel.PIECEWISE:
+           cudagraph_mode = compilation_config.cudagraph_mode
+           if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
+               CUDAGraphMode.PIECEWISE
+           ):
+               logger.warning(
+                   "Currently the eagle proposer only supports cudagraph_mode "
+                   "PIECEWISE, if you want the drafter to use cuda graphs, "
+                   "please set compilation_config.cudagraph_mode to PIECEWISE "
+                   "or FULL_AND_PIECEWISE"
+               )
+           self.use_cuda_graph = (
+               cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
+               and not self.speculative_config.enforce_eager
+           )

        self.cudagraph_batch_sizes = (
            list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes))
            if self.use_cuda_graph
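The net gating condition can be summarised as a small predicate (a simplified standalone restatement of the logic above, with the config lookups flattened into booleans):

def drafter_uses_cuda_graph(
    piecewise_compilation: bool,
    cudagraph_mode_has_piecewise: bool,
    speculative_enforce_eager: bool,
) -> bool:
    # The Eagle head only supports PIECEWISE graphs, so it opts in only when
    # piecewise compilation is active, the configured cudagraph_mode contains
    # PIECEWISE (PIECEWISE itself or FULL_AND_PIECEWISE), and the speculative
    # config does not force eager execution.
    return (
        piecewise_compilation
        and cudagraph_mode_has_piecewise
        and not speculative_enforce_eager
    )

assert drafter_uses_cuda_graph(True, True, False)
assert not drafter_uses_cuda_graph(True, False, False)  # e.g. FULL-only mode
assert not drafter_uses_cuda_graph(True, True, True)    # drafter forced eager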
@@ -239,12 +257,15 @@ class EagleProposer:
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

+       cudagraph_runtime_mode = CUDAGraphMode.NONE
        if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+           cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            num_input_tokens = num_tokens
        # copy inputs to buffer for cudagraph
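The padding rule is the usual capture-size bucketing: a draft batch that fits under the largest captured size is padded up to the nearest captured size and replayed as a PIECEWISE graph, while anything larger runs eagerly. A standalone sketch of that selection (pad_to_capture_size is an illustrative stand-in for vllm_config.pad_for_cudagraph plus the mode choice):

from bisect import bisect_left

def pad_to_capture_size(num_tokens: int, capture_sizes: list[int]) -> tuple[int, str]:
    # capture_sizes is assumed sorted ascending, e.g. [1, 2, 4, 8, 16].
    if num_tokens <= capture_sizes[-1]:
        padded = capture_sizes[bisect_left(capture_sizes, num_tokens)]
        return padded, "PIECEWISE"  # replay the captured piecewise graph
    return num_tokens, "NONE"       # too large: fall back to eager execution

assert pad_to_capture_size(3, [1, 2, 4, 8]) == (4, "PIECEWISE")
assert pad_to_capture_size(9, [1, 2, 4, 8]) == (9, "NONE")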
@@ -267,7 +288,10 @@ class EagleProposer:
            inputs_embeds = None

        with set_forward_context(
-           per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
+           per_layer_attn_metadata,
+           self.vllm_config,
+           num_tokens=num_input_tokens,
+           cudagraph_runtime_mode=cudagraph_runtime_mode,
        ):
            ret_hidden_states = self.model(
                input_ids=input_ids,
@@ -326,8 +350,10 @@ class EagleProposer:
        if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
+           cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            input_batch_size = batch_size
+           cudagraph_runtime_mode = CUDAGraphMode.NONE

        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
@@ -424,7 +450,10 @@ class EagleProposer:
        # Run the model.
        with set_forward_context(
-           per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size
+           per_layer_attn_metadata,
+           self.vllm_config,
+           num_tokens=input_batch_size,
+           cudagraph_runtime_mode=cudagraph_runtime_mode,
        ):
            ret_hidden_states = self.model(
                input_ids=input_ids,
@@ -731,11 +760,16 @@ class EagleProposer:
        if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+           cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
        else:
            num_input_tokens = num_tokens
+           cudagraph_runtime_mode = CUDAGraphMode.NONE

        # Run the model.
        with set_forward_context(
-           per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
+           per_layer_attn_metadata,
+           self.vllm_config,
+           num_tokens=num_input_tokens,
+           cudagraph_runtime_mode=cudagraph_runtime_mode,
        ):
            last_hidden_states, hidden_states = self.model(
                input_ids=self.input_ids[:num_input_tokens],
@@ -1015,8 +1049,19 @@ class EagleProposer:
    def dummy_run(
        self,
        num_tokens: int,
+       use_cudagraphs=True,
    ) -> None:
-       with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
+       if use_cudagraphs and num_tokens <= self.cudagraph_batch_sizes[-1]:
+           num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+
+       with set_forward_context(
+           None,
+           self.vllm_config,
+           num_tokens=num_tokens,
+           cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
+           if use_cudagraphs
+           else CUDAGraphMode.NONE,
+       ):
            if self.supports_mm_inputs:
                input_ids = None
                inputs_embeds = self.inputs_embeds[:num_tokens]

View File

@@ -3441,7 +3441,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        if self.speculative_config and self.speculative_config.use_eagle():
            assert isinstance(self.drafter, EagleProposer)
-           self.drafter.dummy_run(num_tokens)
+           use_cudagraphs = cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
+           self.drafter.dummy_run(num_tokens, use_cudagraphs=use_cudagraphs)

        # This is necessary to avoid blocking DP.
        # For dummy runs, we typically skip EPLB since we don't have any real
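Putting the runner and drafter halves together, a condensed sketch of how a capture-time dummy run now reaches the Eagle drafter (plain functions standing in for GPUModelRunner._dummy_run and EagleProposer.dummy_run, with made-up capture sizes):

from enum import Enum

class CUDAGraphMode(Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2

CAPTURE_SIZES = [1, 2, 4, 8, 16]  # ascending, largest last

def drafter_dummy_run(num_tokens: int, use_cudagraphs: bool = True) -> tuple[int, CUDAGraphMode]:
    # Pad to a captured size so the drafter replays the matching piecewise graph.
    if use_cudagraphs and num_tokens <= CAPTURE_SIZES[-1]:
        num_tokens = min(s for s in CAPTURE_SIZES if s >= num_tokens)
    mode = CUDAGraphMode.PIECEWISE if use_cudagraphs else CUDAGraphMode.NONE
    return num_tokens, mode  # the real method runs the model under set_forward_context

def runner_dummy_run(num_tokens: int, cudagraph_runtime_mode: CUDAGraphMode) -> tuple[int, CUDAGraphMode]:
    # The runner only asks the drafter for cudagraphs during its own
    # PIECEWISE capture/replay passes.
    use_cudagraphs = cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
    return drafter_dummy_run(num_tokens, use_cudagraphs=use_cudagraphs)

assert runner_dummy_run(3, CUDAGraphMode.PIECEWISE) == (4, CUDAGraphMode.PIECEWISE)
assert runner_dummy_run(3, CUDAGraphMode.FULL) == (3, CUDAGraphMode.NONE)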