[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE3 (#17504)

Signed-off-by: qizixi <qizixi@meta.com>
Authored by qizixi on 2025-05-01 16:19:30 -07:00; committed by GitHub
parent 9b70e2b4c1
commit 39c0813a7f
2 changed files with 36 additions and 31 deletions
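
For context, enabling this path end-to-end looks roughly like the sketch below. The model names and speculative settings are illustrative assumptions, not values taken from this commit; the speculative_config dict is the V1-era way to select the EAGLE3 drafter.

    # Hypothetical usage sketch: serve a target model with an EAGLE3 draft
    # model under vLLM V1. Model names and token counts are assumptions.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # assumed target model
        speculative_config={
            "method": "eagle3",  # routes through EagleProposer's eagle3 path
            "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",  # assumed drafter
            "num_speculative_tokens": 2,
        },
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=16))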

vllm/model_executor/models/llama_eagle3.py

@@ -6,7 +6,8 @@ import torch
 import torch.nn as nn
 from transformers import LlamaConfig
 
-from vllm.config import ModelConfig, VllmConfig
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -76,17 +77,19 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         return hidden_states, residual
 
 
+@support_torch_compile
 class LlamaModel(nn.Module):
 
     def __init__(
         self,
         *,
-        model_config: ModelConfig,
+        vllm_config: VllmConfig,
        start_layer_id: int = 0,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.config = model_config.hf_config
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             self.config.vocab_size,
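
The @support_torch_compile decorator is why LlamaModel.__init__ now takes the full vllm_config: vLLM's compilation machinery reads its CompilationConfig from there. As a rough mental model only (this is not vLLM's actual implementation), a class decorator of this shape wraps forward in torch.compile on first call:

    # Toy sketch of a compile-on-first-call class decorator. vLLM's real
    # support_torch_compile integrates with CompilationConfig and piecewise
    # cudagraph capture rather than plain torch.compile.
    import torch

    def toy_support_torch_compile(cls):
        orig_forward = cls.forward

        def forward(self, *args, **kwargs):
            if not hasattr(self, "_compiled_forward"):
                # Compile once, then reuse the compiled callable.
                self._compiled_forward = torch.compile(orig_forward)
            return self._compiled_forward(self, *args, **kwargs)

        cls.forward = forward
        return cls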
@@ -119,8 +122,7 @@ class LlamaModel(nn.Module):
         hidden_states: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_embeds = self.embed_tokens(input_ids)
-        if (hidden_states.shape[-1] != input_embeds.shape[-1]):
-            hidden_states = self.fc(hidden_states)
+        assert hidden_states.shape[-1] == input_embeds.shape[-1]
         residual = None
         hidden_states, residual = self.layers[0](
@@ -169,9 +171,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
 
     def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
         nn.Module.__init__(self)
-        model_config = vllm_config.speculative_config.draft_model_config
-        self.config = model_config.hf_config
-        self.model = LlamaModel(model_config=model_config,
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
+        self.model = LlamaModel(vllm_config=vllm_config,
                                 start_layer_id=start_layer_id,
                                 prefix="model")
@@ -214,6 +216,13 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         logits_new[:, targets] = logits
         return logits_new
 
+    def combine_hidden_states(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        # combine multiple auxiliary hidden states returned by eagle3
+        return self.model.fc(hidden_states)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(
             self,
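
Factoring the fc projection out into combine_hidden_states lets the rest of forward run on a fixed hidden_size, which the new assert in LlamaModel.forward enforces. In EAGLE3 the drafter conditions on auxiliary hidden states taken from several target-model layers, concatenated along the feature dimension, and self.model.fc projects that back down. A standalone sketch of the shape contract (the factor of 3 is the usual EAGLE3 setup, assumed here):

    # Shape-contract sketch for combine_hidden_states; sizes are assumptions.
    import torch
    import torch.nn as nn

    hidden_size, num_aux_layers = 4096, 3
    fc = nn.Linear(hidden_size * num_aux_layers, hidden_size, bias=False)

    # Concatenated auxiliary hidden states from the target model.
    aux = torch.randn(10, hidden_size * num_aux_layers)
    combined = fc(aux)
    assert combined.shape[-1] == hidden_size  # matches the proposer's assert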

vllm/v1/spec_decode/eagle.py

@@ -10,6 +10,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -39,11 +40,9 @@ class EagleProposer:
         self.hidden_size = vllm_config.model_config.get_hidden_size()
 
-        # TODO: make eagle3 compatible with cudagraph
-        self.use_cuda_graph = self.method != 'eagle3' and \
-            (self.vllm_config.compilation_config.level
-             == CompilationLevel.PIECEWISE and
-             not self.vllm_config.model_config.enforce_eager)
+        self.use_cuda_graph = (self.vllm_config.compilation_config.level
+                               == CompilationLevel.PIECEWISE and
+                               not self.vllm_config.model_config.enforce_eager)
 
         self.cudagraph_batch_sizes = list(
             reversed(
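
With the self.method != 'eagle3' escape hatch removed, CUDA graph use is gated only by the compilation level and enforce_eager, so cudagraph_batch_sizes now applies to EAGLE3 as well. At run time the proposer pads each batch up to the nearest captured size; a self-contained sketch of that padding rule (hypothetical helper name, not vLLM's API):

    # Hypothetical helper mirroring how a cudagraph runner picks the padded
    # size: smallest captured size that fits, else fall back to eager.
    def pad_to_captured(batch_size: int, captured_sizes: list[int]) -> int:
        for size in sorted(captured_sizes):
            if size >= batch_size:
                return size
        return batch_size

    assert pad_to_captured(3, [1, 2, 4, 8]) == 4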
@@ -90,6 +89,12 @@ class EagleProposer:
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
 
+        if self.method == "eagle3":
+            assert isinstance(self.model, Eagle3LlamaForCausalLM)
+            target_hidden_states = self.model.combine_hidden_states(
+                target_hidden_states)
+            assert target_hidden_states.shape[-1] == self.hidden_size
+
         # Shift the input ids by one token.
         # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
         self.input_ids[:num_tokens - 1] = target_token_ids[1:]
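
The shift-by-one above is the core EAGLE input trick: drafter position i is fed the token the target model produced at position i+1. A tiny runnable version of the comment's example:

    import torch

    # [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
    target_token_ids = torch.tensor([11, 21, 22, 31, 32, 33])
    num_tokens = target_token_ids.numel()
    input_ids = torch.empty_like(target_token_ids)
    input_ids[:num_tokens - 1] = target_token_ids[1:]
    # Placeholder for the trailing slot; the real code then overwrites each
    # request's last position with its sampled next_token_id.
    input_ids[num_tokens - 1] = target_token_ids[-1]
    print(input_ids.tolist())  # [21, 22, 31, 32, 33, 33]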
@@ -126,12 +131,7 @@ class EagleProposer:
         # copy inputs to buffer for cudagraph
         self.positions[:num_tokens] = target_positions
-        if self.method == 'eagle':
-            self.hidden_states[:num_tokens] = target_hidden_states
-            hidden_states = self.hidden_states
-        else:
-            # TODO: make eagle3 compatible with cuda graph
-            hidden_states = target_hidden_states
+        self.hidden_states[:num_tokens] = target_hidden_states
 
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
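
Writing into the preallocated self.positions / self.hidden_states buffers unconditionally is what makes graph replay safe: a captured CUDA graph reads from fixed device addresses, so per-step inputs must be copied into the same tensors rather than rebound to new ones. The standard PyTorch pattern, independent of vLLM:

    import torch

    if torch.cuda.is_available():
        model = torch.nn.Linear(16, 16, device="cuda")
        static_in = torch.zeros(8, 16, device="cuda")

        # Warm up on a side stream (required before capture).
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                model(static_in)
        torch.cuda.current_stream().wait_stream(s)

        # Capture once: the graph records reads from static_in's address.
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_out = model(static_in)

        # Per step: copy new data into the SAME buffer, never rebind it.
        static_in.copy_(torch.randn(8, 16, device="cuda"))
        g.replay()  # static_out now holds outputs for the new input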
@@ -139,7 +139,7 @@
             last_hidden_states, hidden_states = self.model(
                 input_ids=self.input_ids[:num_input_tokens],
                 positions=self.positions[:num_input_tokens],
-                hidden_states=hidden_states[:num_input_tokens],
+                hidden_states=self.hidden_states[:num_input_tokens],
             )
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
@@ -209,10 +209,7 @@ class EagleProposer:
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
-            if self.method == 'eagle':
-                # TODO: make eagle3 compatible with cudagraph.
-                self.hidden_states[:batch_size] = hidden_states
-                hidden_states = self.hidden_states
+            self.hidden_states[:batch_size] = hidden_states
 
             # Run the model.
             with set_forward_context(attn_metadata,
@@ -221,7 +218,7 @@
             last_hidden_states, hidden_states = self.model(
                 input_ids=self.input_ids[:input_batch_size],
                 positions=self.positions[:input_batch_size],
-                hidden_states=hidden_states[:input_batch_size],
+                hidden_states=self.hidden_states[:input_batch_size],
             )
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
@@ -314,12 +311,11 @@
     ) -> None:
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
-            if self.method == 'eagle':
-                self.model(
-                    input_ids=self.input_ids[:num_tokens],
-                    positions=self.positions[:num_tokens],
-                    hidden_states=self.hidden_states[:num_tokens],
-                )
+            self.model(
+                input_ids=self.input_ids[:num_tokens],
+                positions=self.positions[:num_tokens],
+                hidden_states=self.hidden_states[:num_tokens],
+            )
 
 
 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
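
With the eagle/eagle3 branch gone, dummy_run always drives the draft model, which is what lets warmup record piecewise CUDA graphs for EAGLE3 at every capture size. A rough sketch of the warmup shape (hypothetical function; the real loop lives in vLLM's GPU model runner):

    # Hypothetical warmup loop: run the proposer once per capture size,
    # largest first, so each piecewise CUDA graph is recorded up front.
    def warm_up_proposer(proposer, cudagraph_batch_sizes: list[int]) -> None:
        for num_tokens in sorted(cudagraph_batch_sizes, reverse=True):
            proposer.dummy_run(num_tokens)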