[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE (#17211)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
This commit is contained in:
Bryan Lu 2025-04-29 14:10:00 -07:00 committed by GitHub
parent c9c1b59e59
commit 70788bdbdc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 152 additions and 53 deletions

View File

@ -36,6 +36,10 @@ def parse_args():
help="downloaded from the eagle repo " \ help="downloaded from the eagle repo " \
"https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/" "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
) )
parser.add_argument("--method",
type=str,
default='eagle',
choices=['eagle', 'eagle3'])
parser.add_argument("--max_num_seqs", type=int, default=8) parser.add_argument("--max_num_seqs", type=int, default=8)
parser.add_argument("--num_prompts", type=int, default=80) parser.add_argument("--num_prompts", type=int, default=80)
parser.add_argument("--num_spec_tokens", type=int, default=2) parser.add_argument("--num_spec_tokens", type=int, default=2)
@ -53,7 +57,13 @@ def main():
args = parse_args() args = parse_args()
model_dir = "meta-llama/Llama-3.1-8B-Instruct" model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
if args.method == 'eagle':
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
elif args.method == 'eagle3':
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
else:
raise ValueError(f"unknown method: {args.method}")
max_model_len = 2048 max_model_len = 2048
@ -81,7 +91,7 @@ def main():
max_num_seqs=args.max_num_seqs, max_num_seqs=args.max_num_seqs,
gpu_memory_utilization=0.8, gpu_memory_utilization=0.8,
speculative_config={ speculative_config={
"method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle", "method": args.method,
"model": eagle_dir, "model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens, "num_speculative_tokens": args.num_spec_tokens,
"draft_tensor_parallel_size": args.draft_tp, "draft_tensor_parallel_size": args.draft_tp,

View File

@ -347,8 +347,12 @@ class VllmBackend:
PASS_KEY = "post_grad_custom_post_pass" PASS_KEY = "post_grad_custom_post_pass"
if PASS_KEY in inductor_config: if PASS_KEY in inductor_config:
# Config should automatically wrap all inductor passes # Config should automatically wrap all inductor passes
assert isinstance(inductor_config[PASS_KEY], InductorPass) if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) assert (inductor_config[PASS_KEY].uuid() ==
self.post_grad_pass_manager.uuid())
else:
assert isinstance(inductor_config[PASS_KEY], InductorPass)
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
inductor_config[PASS_KEY] = self.post_grad_pass_manager inductor_config[PASS_KEY] = self.post_grad_pass_manager
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
@ -408,8 +412,13 @@ class VllmBackend:
) )
self.compilation_config.cache_dir = cache_dir self.compilation_config.cache_dir = cache_dir
cache_dir = self.compilation_config.cache_dir if compilation_counter.num_graphs_seen > 0:
cache_dir = self.compilation_config.cache_dir + \
f'-{compilation_counter.num_graphs_seen}'
else:
cache_dir = self.compilation_config.cache_dir
os.makedirs(cache_dir, exist_ok=True) os.makedirs(cache_dir, exist_ok=True)
self.compilation_config.cache_dir = cache_dir
rank = vllm_config.parallel_config.rank rank = vllm_config.parallel_config.rank
dp_rank = vllm_config.parallel_config.data_parallel_rank dp_rank = vllm_config.parallel_config.data_parallel_rank
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}") local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")

View File

@ -6,7 +6,8 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.config import ModelConfig from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -37,17 +38,19 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
self.input_layernorm = nn.Identity() self.input_layernorm = nn.Identity()
@support_torch_compile
class LlamaModel(nn.Module): class LlamaModel(nn.Module):
def __init__( def __init__(
self, self,
*, *,
model_config: ModelConfig, vllm_config: VllmConfig,
start_layer_id: int = 0,
prefix: str = "", prefix: str = "",
start_layer_id: int = 0,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = model_config.hf_config self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size self.vocab_size = self.config.vocab_size
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size, self.config.vocab_size,
@ -75,8 +78,7 @@ class LlamaModel(nn.Module):
hidden_states = self.fc( hidden_states = self.fc(
torch.cat((input_embeds, hidden_states), dim=-1)) torch.cat((input_embeds, hidden_states), dim=-1))
residual = None residual = None
for i in range(len(self.layers)): for layer in self.layers:
layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
@ -117,12 +119,13 @@ class LlamaModel(nn.Module):
class EagleLlamaForCausalLM(LlamaForCausalLM): class EagleLlamaForCausalLM(LlamaForCausalLM):
def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0): def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
nn.Module.__init__(self) nn.Module.__init__(self)
self.config = model_config.hf_config self.config = vllm_config. \
self.model = LlamaModel(model_config=model_config, speculative_config.draft_model_config.hf_config
start_layer_id=start_layer_id, self.model = LlamaModel(vllm_config=vllm_config,
prefix="model") prefix="model",
start_layer_id=start_layer_id)
logit_scale = getattr(self.config, "logit_scale", 1.0) logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.config.vocab_size, self.logits_processor = LogitsProcessor(self.config.vocab_size,

View File

@ -6,7 +6,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.config import ModelConfig from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear from vllm.model_executor.layers.linear import QKVParallelLinear
@ -167,8 +167,9 @@ class LlamaModel(nn.Module):
class Eagle3LlamaForCausalLM(LlamaForCausalLM): class Eagle3LlamaForCausalLM(LlamaForCausalLM):
def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0): def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
nn.Module.__init__(self) nn.Module.__init__(self)
model_config = vllm_config.speculative_config.draft_model_config
self.config = model_config.hf_config self.config = model_config.hf_config
self.model = LlamaModel(model_config=model_config, self.model = LlamaModel(model_config=model_config,
start_layer_id=start_layer_id, start_layer_id=start_layer_id,

View File

@ -4,7 +4,7 @@ import torch.nn as nn
import triton import triton
import triton.language as tl import triton.language as tl
from vllm.config import VllmConfig, set_current_vllm_config from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader.loader import get_model_loader from vllm.model_executor.model_loader.loader import get_model_loader
@ -26,10 +26,41 @@ class EagleProposer:
device: torch.device, device: torch.device,
): ):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.method = self.vllm_config.speculative_config.method
self.num_speculative_tokens = ( self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens) vllm_config.speculative_config.num_speculative_tokens)
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
self.dtype = vllm_config.model_config.dtype
self.max_num_tokens = vllm_config.scheduler_config \
.max_num_batched_tokens
self.hidden_size = vllm_config.model_config.get_hidden_size()
# TODO: make eagle3 compatible with cudagraph
self.use_cuda_graph = self.method != 'eagle3' and \
(self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager)
self.cudagraph_batch_sizes = list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
# persistent buffers for cuda graph
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=device)
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=device)
self.hidden_states = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=device)
# We need +1 here because the arange is used to set query_start_loc, # We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size. # which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@ -59,13 +90,12 @@ class EagleProposer:
batch_size = next_token_ids.shape[0] batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1 last_token_indices = cu_num_tokens[1:] - 1
input_ids = torch.empty_like(target_token_ids)
# Shift the input ids by one token. # Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
input_ids[:-1] = target_token_ids[1:] self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token. # Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
input_ids[last_token_indices] = next_token_ids self.input_ids[last_token_indices] = next_token_ids
# FA requires seq_len to have dtype int32. # FA requires seq_len to have dtype int32.
seq_lens = (target_positions[last_token_indices] + 1).int() seq_lens = (target_positions[last_token_indices] + 1).int()
@ -88,14 +118,30 @@ class EagleProposer:
prefix_kv_lens=None, prefix_kv_lens=None,
suffix_kv_lens=None, suffix_kv_lens=None,
) )
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
with set_forward_context(attn_metadata, self.vllm_config): if self.method == 'eagle':
hidden_states_logits, hidden_states_fwd = self.model( self.hidden_states[:num_tokens] = target_hidden_states
input_ids=input_ids, hidden_states = self.hidden_states
hidden_states=target_hidden_states, else:
positions=target_positions, # TODO: make eagle3 compatible with cuda graph
hidden_states = target_hidden_states
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=hidden_states[:num_input_tokens],
) )
sample_hidden_states = hidden_states_logits[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
@ -108,13 +154,20 @@ class EagleProposer:
draft_token_ids_list = [draft_token_ids] draft_token_ids_list = [draft_token_ids]
positions = target_positions[last_token_indices] positions = target_positions[last_token_indices]
hidden_states = hidden_states_fwd[last_token_indices] hidden_states = hidden_states[last_token_indices]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
attn_metadata.num_actual_tokens = batch_size attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1 attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1] attn_metadata.query_start_loc = self.arange[:batch_size + 1]
for _ in range(self.num_speculative_tokens - 1): for _ in range(self.num_speculative_tokens - 1):
# Update the inputs. # Update the inputs.
input_ids = draft_token_ids_list[-1] # cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids = draft_token_ids_list[-1].int()
positions += 1 positions += 1
# NOTE(woosuk): We should handle the case where the draft model # NOTE(woosuk): We should handle the case where the draft model
@ -152,14 +205,27 @@ class EagleProposer:
attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
PADDING_SLOT_ID) PADDING_SLOT_ID)
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
if self.method == 'eagle':
# TODO: make eagle3 compatible with cudagraph.
self.hidden_states[:batch_size] = hidden_states
hidden_states = self.hidden_states
# Run the model. # Run the model.
with set_forward_context(attn_metadata, self.vllm_config): with set_forward_context(attn_metadata,
hidden_states_logits, hidden_states = self.model( self.vllm_config,
input_ids=input_ids, num_tokens=input_batch_size):
hidden_states=hidden_states, last_hidden_states, hidden_states = self.model(
positions=clamped_positions, input_ids=self.input_ids[:input_batch_size],
positions=self.positions[:input_batch_size],
hidden_states=hidden_states[:input_batch_size],
) )
logits = self.model.compute_logits(hidden_states_logits, None) hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids) draft_token_ids_list.append(draft_token_ids)
@ -227,13 +293,11 @@ class EagleProposer:
draft_model_cls, arch = ModelRegistry.resolve_model_cls( draft_model_cls, arch = ModelRegistry.resolve_model_cls(
draft_model_config.architectures) draft_model_config.architectures)
self.model = draft_model_cls( self.model = draft_model_cls(
model_config=draft_model_config, vllm_config=self.vllm_config,
start_layer_id=target_layer_num).to(target_device) start_layer_id=target_layer_num).to(target_device)
loaded_weights = self.model.load_weights( loaded_weights = self.model.load_weights(
loader.get_all_weights( loader.get_all_weights(draft_model_config, self.model))
self.vllm_config.speculative_config.draft_model_config,
self.model))
if self.vllm_config.speculative_config.method == "eagle3": if self.vllm_config.speculative_config.method == "eagle3":
if "model.embed_tokens.weight" not in loaded_weights: if "model.embed_tokens.weight" not in loaded_weights:
logger.info( logger.info(
@ -243,6 +307,20 @@ class EagleProposer:
logger.info("Loading EAGLE LM head weights from the target model.") logger.info("Loading EAGLE LM head weights from the target model.")
self.model.lm_head = target_model.lm_head self.model.lm_head = target_model.lm_head
@torch.inference_mode()
def dummy_run(self, num_tokens: int) -> None:
    """Run the draft model once over the first ``num_tokens`` buffer slots.

    The forward context is always entered with ``attn_metadata=None`` and
    the given ``num_tokens``. The model itself is only invoked when
    ``self.method`` is ``'eagle'``; for any other method the context is
    entered and exited without a forward pass. The model output is
    discarded.

    Args:
        num_tokens: Number of leading entries of the persistent
            ``input_ids``, ``positions`` and ``hidden_states`` buffers to
            feed to the model.
    """
    forward_ctx = set_forward_context(None,
                                      self.vllm_config,
                                      num_tokens=num_tokens)
    with forward_ctx:
        # Returning from inside the ``with`` still exits the context.
        if self.method != 'eagle':
            return
        # Inputs are views into the persistent buffers, not fresh tensors.
        dummy_inputs = {
            "input_ids": self.input_ids[:num_tokens],
            "positions": self.positions[:num_tokens],
            "hidden_states": self.hidden_states[:num_tokens],
        }
        self.model(**dummy_inputs)
# NOTE(woosuk): Currently, the below code is not used and we always use argmax # NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage # to sample the draft tokens. We will use this after we find a way to manage

View File

@ -1106,7 +1106,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# For mid-pipeline stages, return the hidden states. # For mid-pipeline stages, return the hidden states.
return hidden_states return hidden_states
hidden_states = hidden_states[:num_scheduled_tokens]
sample_hidden_states = hidden_states[logits_indices] sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None) logits = self.model.compute_logits(sample_hidden_states, None)
@ -1172,7 +1171,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Compute prompt logprobs if needed. # Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict( prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states, hidden_states[:num_scheduled_tokens],
scheduler_output, scheduler_output,
) )
@ -1222,15 +1221,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if spec_decode_metadata is None: if spec_decode_metadata is None:
# input_ids can be None for multimodal models. # input_ids can be None for multimodal models.
# We need to slice token_ids, positions, and hidden_states
# because the eagle head does not use cuda graph and should
# not include padding.
target_token_ids = self.input_ids[:num_scheduled_tokens] target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens] target_positions = positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
target_hidden_states = [ target_hidden_states = torch.cat(
h[:num_scheduled_tokens] for h in aux_hidden_states [h[:num_scheduled_tokens] for h in aux_hidden_states],
] dim=-1)
else: else:
target_hidden_states = hidden_states[:num_scheduled_tokens] target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = attn_metadata.slot_mapping target_slot_mapping = attn_metadata.slot_mapping
@ -1254,15 +1250,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_token_ids = self.input_ids[token_indices] target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices] target_positions = positions[token_indices]
if self.use_aux_hidden_state_outputs: if self.use_aux_hidden_state_outputs:
target_hidden_states = [ target_hidden_states = torch.cat(
h[token_indices] for h in aux_hidden_states [h[token_indices] for h in aux_hidden_states], dim=-1)
]
else: else:
target_hidden_states = hidden_states[token_indices] target_hidden_states = hidden_states[token_indices]
target_slot_mapping = attn_metadata.slot_mapping[token_indices] target_slot_mapping = attn_metadata.slot_mapping[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(target_hidden_states, dim=-1)
draft_token_ids = self.drafter.propose( draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids, target_token_ids=target_token_ids,
target_positions=target_positions, target_positions=target_positions,
@ -1506,6 +1499,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else: else:
hidden_states = outputs hidden_states = outputs
if self.use_spec_decode and \
self.speculative_config.method in ('eagle', 'eagle3'):
assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens)
logit_indices = np.cumsum(num_scheduled_tokens) - 1 logit_indices = np.cumsum(num_scheduled_tokens) - 1
return hidden_states[logit_indices] return hidden_states[logit_indices]