enable piecewise cudagraphs for eagle

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-09-17 20:48:14 +00:00
parent 883131544f
commit 90d24dee04
2 changed files with 37 additions and 12 deletions

View File

@@ -321,6 +321,13 @@ def set_forward_context(
attn_metadata, num_tokens or 0,
num_tokens_across_dp)
# Convenience: if cudagraph is used, and num_tokens is given, we can just
# create a batch descriptor here if not given (there's no harm since if it
# doesn't match in the wrapper it'll fall through).
if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
batch_descriptor = batch_descriptor or BatchDescriptor(
num_tokens=num_tokens)
forward_context = create_forward_context(attn_metadata, vllm_config,
virtual_engine, dp_metadata,
cudagraph_runtime_mode,

View File

@@ -10,7 +10,7 @@ import torch
import torch.nn as nn
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
@@ -78,6 +78,11 @@ class EagleProposer:
self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager)
self.cudagraph_runtime_mode = (CUDAGraphMode.PIECEWISE
if self.use_cuda_graph else
CUDAGraphMode.NONE)
self.cudagraph_batch_sizes = list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
@@ -212,9 +217,12 @@ class EagleProposer:
inputs_embeds = None
input_ids = self.input_ids[:num_input_tokens]
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
cudagraph_runtime_mode=self.cudagraph_runtime_mode,
):
ret_hidden_states = self.model(
input_ids=input_ids,
positions=self.positions[:num_input_tokens],
@@ -322,9 +330,12 @@ class EagleProposer:
input_ids = self.input_ids[:input_batch_size]
# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size,
cudagraph_runtime_mode=self.cudagraph_runtime_mode,
):
ret_hidden_states = self.model(
input_ids=input_ids,
positions=self.positions[:input_batch_size],
@@ -478,9 +489,12 @@ class EagleProposer:
else:
num_input_tokens = num_tokens
# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
with set_forward_context(
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
cudagraph_runtime_mode=self.cudagraph_runtime_mode,
):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
@@ -665,8 +679,12 @@ class EagleProposer:
self,
num_tokens: int,
) -> None:
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
with set_forward_context(
None,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=self.cudagraph_runtime_mode,
):
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]