[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE (#17211)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Bryan Lu 2025-04-29 14:10:00 -07:00 committed by GitHub
parent c9c1b59e59
commit 70788bdbdc
6 changed files with 152 additions and 53 deletions

View File

@@ -36,6 +36,10 @@ def parse_args():
help="downloaded from the eagle repo " \
"https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
)
parser.add_argument("--method",
type=str,
default='eagle',
choices=['eagle', 'eagle3'])
parser.add_argument("--max_num_seqs", type=int, default=8)
parser.add_argument("--num_prompts", type=int, default=80)
parser.add_argument("--num_spec_tokens", type=int, default=2)
@@ -53,7 +57,13 @@ def main():
args = parse_args()
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
if args.method == 'eagle':
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
elif args.method == 'eagle3':
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
else:
raise ValueError(f"unknown method: {args.method}")
max_model_len = 2048
@@ -81,7 +91,7 @@ def main():
max_num_seqs=args.max_num_seqs,
gpu_memory_utilization=0.8,
speculative_config={
"method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle",
"method": args.method,
"model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens,
"draft_tensor_parallel_size": args.draft_tp,

View File

@@ -347,8 +347,12 @@ class VllmBackend:
PASS_KEY = "post_grad_custom_post_pass"
if PASS_KEY in inductor_config:
# Config should automatically wrap all inductor passes
assert isinstance(inductor_config[PASS_KEY], InductorPass)
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
assert (inductor_config[PASS_KEY].uuid() ==
self.post_grad_pass_manager.uuid())
else:
assert isinstance(inductor_config[PASS_KEY], InductorPass)
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
inductor_config[PASS_KEY] = self.post_grad_pass_manager
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
@@ -408,8 +412,13 @@ class VllmBackend:
)
self.compilation_config.cache_dir = cache_dir
cache_dir = self.compilation_config.cache_dir
if compilation_counter.num_graphs_seen > 0:
cache_dir = self.compilation_config.cache_dir + \
f'-{compilation_counter.num_graphs_seen}'
else:
cache_dir = self.compilation_config.cache_dir
os.makedirs(cache_dir, exist_ok=True)
self.compilation_config.cache_dir = cache_dir
rank = vllm_config.parallel_config.rank
dp_rank = vllm_config.parallel_config.data_parallel_rank
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
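
The hunk above gives each additionally compiled graph its own inductor cache directory. A small, hedged sketch (hypothetical helper name and paths) of the naming rule it introduces:

def pick_cache_dir(base_dir: str, num_graphs_seen: int) -> str:
    # First graph (e.g. the target model) keeps the configured dir; later graphs
    # (e.g. the EAGLE draft head) get an index suffix so artifacts do not collide.
    if num_graphs_seen > 0:
        return f"{base_dir}-{num_graphs_seen}"
    return base_dir

assert pick_cache_dir("/some/cache/dir", 0) == "/some/cache/dir"
assert pick_cache_dir("/some/cache/dir", 1) == "/some/cache/dir-1"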

View File

@@ -6,7 +6,8 @@ import torch
import torch.nn as nn
from transformers import LlamaConfig
from vllm.config import ModelConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -37,17 +38,19 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
self.input_layernorm = nn.Identity()
@support_torch_compile
class LlamaModel(nn.Module):
def __init__(
self,
*,
model_config: ModelConfig,
start_layer_id: int = 0,
vllm_config: VllmConfig,
prefix: str = "",
start_layer_id: int = 0,
) -> None:
super().__init__()
self.config = model_config.hf_config
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
@@ -75,8 +78,7 @@ class LlamaModel(nn.Module):
hidden_states = self.fc(
torch.cat((input_embeds, hidden_states), dim=-1))
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
for layer in self.layers:
hidden_states, residual = layer(
positions,
hidden_states,
@@ -117,12 +119,13 @@ class LlamaModel(nn.Module):
class EagleLlamaForCausalLM(LlamaForCausalLM):
def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
nn.Module.__init__(self)
self.config = model_config.hf_config
self.model = LlamaModel(model_config=model_config,
start_layer_id=start_layer_id,
prefix="model")
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
self.model = LlamaModel(vllm_config=vllm_config,
prefix="model",
start_layer_id=start_layer_id)
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.config.vocab_size,
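
The draft model's constructor switches from model_config to vllm_config because the newly applied @support_torch_compile decorator, as used elsewhere in vLLM, expects the wrapped module to take vllm_config and prefix keyword arguments. A hedged sketch of that expected shape, with a toy module name and toy forward:

import torch
import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig

@support_torch_compile
class ToyDraftModel(nn.Module):
    # Extra keyword arguments such as start_layer_id pass through unchanged;
    # the decorator relies on vllm_config and prefix being present.
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "",
                 start_layer_id: int = 0) -> None:
        super().__init__()
        self.start_layer_id = start_layer_id

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        return positions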

View File

@@ -6,7 +6,7 @@ import torch
import torch.nn as nn
from transformers import LlamaConfig
from vllm.config import ModelConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -167,8 +167,9 @@ class LlamaModel(nn.Module):
class Eagle3LlamaForCausalLM(LlamaForCausalLM):
def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
nn.Module.__init__(self)
model_config = vllm_config.speculative_config.draft_model_config
self.config = model_config.hf_config
self.model = LlamaModel(model_config=model_config,
start_layer_id=start_layer_id,

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
import triton
import triton.language as tl
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader.loader import get_model_loader
@@ -26,10 +26,41 @@ class EagleProposer:
device: torch.device,
):
self.vllm_config = vllm_config
self.method = self.vllm_config.speculative_config.method
self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens)
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
self.dtype = vllm_config.model_config.dtype
self.max_num_tokens = vllm_config.scheduler_config \
.max_num_batched_tokens
self.hidden_size = vllm_config.model_config.get_hidden_size()
# TODO: make eagle3 compatible with cudagraph
self.use_cuda_graph = self.method != 'eagle3' and \
(self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager)
self.cudagraph_batch_sizes = list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
# persistent buffers for cuda graph
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=device)
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=device)
self.hidden_states = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=device)
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@@ -59,13 +90,12 @@ class EagleProposer:
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
input_ids = torch.empty_like(target_token_ids)
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
input_ids[:-1] = target_token_ids[1:]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
input_ids[last_token_indices] = next_token_ids
self.input_ids[last_token_indices] = next_token_ids
# FA requires seq_len to have dtype int32.
seq_lens = (target_positions[last_token_indices] + 1).int()
@@ -88,14 +118,30 @@ class EagleProposer:
prefix_kv_lens=None,
suffix_kv_lens=None,
)
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states_logits, hidden_states_fwd = self.model(
input_ids=input_ids,
hidden_states=target_hidden_states,
positions=target_positions,
if self.method == 'eagle':
self.hidden_states[:num_tokens] = target_hidden_states
hidden_states = self.hidden_states
else:
# TODO: make eagle3 compatible with cuda graph
hidden_states = target_hidden_states
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=hidden_states[:num_input_tokens],
)
sample_hidden_states = hidden_states_logits[last_token_indices]
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1)
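
The new num_input_tokens / input_batch_size logic pads work to a captured cudagraph size whenever the actual count fits under the largest captured size. A self-contained, hedged sketch of that decision, assuming pad_for_cudagraph rounds up to the nearest captured size:

def pick_padded_size(actual: int, capture_sizes_ascending: list[int],
                     use_cuda_graph: bool) -> int:
    # Mirrors the branch above: only pad when cudagraphs are in use and the
    # actual size does not exceed the largest captured size.
    if use_cuda_graph and actual <= capture_sizes_ascending[-1]:
        return next(s for s in capture_sizes_ascending if s >= actual)
    return actual

assert pick_padded_size(3, [1, 2, 4, 8, 16], use_cuda_graph=True) == 4
assert pick_padded_size(20, [1, 2, 4, 8, 16], use_cuda_graph=True) == 20
assert pick_padded_size(3, [1, 2, 4, 8, 16], use_cuda_graph=False) == 3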
@@ -108,13 +154,20 @@ class EagleProposer:
draft_token_ids_list = [draft_token_ids]
positions = target_positions[last_token_indices]
hidden_states = hidden_states_fwd[last_token_indices]
hidden_states = hidden_states[last_token_indices]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
for _ in range(self.num_speculative_tokens - 1):
# Update the inputs.
input_ids = draft_token_ids_list[-1]
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids = draft_token_ids_list[-1].int()
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
@@ -152,14 +205,27 @@ class EagleProposer:
attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
PADDING_SLOT_ID)
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
if self.method == 'eagle':
# TODO: make eagle3 compatible with cudagraph.
self.hidden_states[:batch_size] = hidden_states
hidden_states = self.hidden_states
# Run the model.
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states_logits, hidden_states = self.model(
input_ids=input_ids,
hidden_states=hidden_states,
positions=clamped_positions,
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:input_batch_size],
positions=self.positions[:input_batch_size],
hidden_states=hidden_states[:input_batch_size],
)
logits = self.model.compute_logits(hidden_states_logits, None)
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
@@ -227,13 +293,11 @@ class EagleProposer:
draft_model_cls, arch = ModelRegistry.resolve_model_cls(
draft_model_config.architectures)
self.model = draft_model_cls(
model_config=draft_model_config,
vllm_config=self.vllm_config,
start_layer_id=target_layer_num).to(target_device)
loaded_weights = self.model.load_weights(
loader.get_all_weights(
self.vllm_config.speculative_config.draft_model_config,
self.model))
loader.get_all_weights(draft_model_config, self.model))
if self.vllm_config.speculative_config.method == "eagle3":
if "model.embed_tokens.weight" not in loaded_weights:
logger.info(
@@ -243,6 +307,20 @@ class EagleProposer:
logger.info("Loading EAGLE LM head weights from the target model.")
self.model.lm_head = target_model.lm_head
@torch.inference_mode()
def dummy_run(
self,
num_tokens: int,
) -> None:
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
if self.method == 'eagle':
self.model(
input_ids=self.input_ids[:num_tokens],
positions=self.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
)
# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
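
The new dummy_run drives the draft model over the persistent input_ids/positions/hidden_states buffers, so warm-up and cudagraph capture touch the same storage that real decode steps reuse. A self-contained, hedged illustration of that slicing pattern, with toy sizes and a stand-in for the model call:

import torch

max_num_tokens, hidden_size = 16, 8
input_ids = torch.zeros(max_num_tokens, dtype=torch.int32)
positions = torch.zeros(max_num_tokens, dtype=torch.int64)
hidden_states = torch.zeros(max_num_tokens, hidden_size)

def dummy_run(num_tokens: int) -> None:
    # Stand-in for self.model(...): only the persistent-buffer slicing matters.
    _ = (input_ids[:num_tokens], positions[:num_tokens],
         hidden_states[:num_tokens])

for size in (1, 2, 4, 8, 16):  # e.g. the cudagraph capture sizes
    dummy_run(size)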

View File

@@ -1106,7 +1106,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# For mid-pipeline stages, return the hidden states.
return hidden_states
hidden_states = hidden_states[:num_scheduled_tokens]
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1172,7 +1171,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states,
hidden_states[:num_scheduled_tokens],
scheduler_output,
)
@@ -1222,15 +1221,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
# We need to slice token_ids, positions, and hidden_states
# because the eagle head does not use cuda graph and should
# not include padding.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = [
h[:num_scheduled_tokens] for h in aux_hidden_states
]
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = attn_metadata.slot_mapping
@@ -1254,15 +1250,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = [
h[token_indices] for h in aux_hidden_states
]
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(target_hidden_states, dim=-1)
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
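
For eagle3, the runner now concatenates the auxiliary hidden states from several target-model layers along the feature dimension before handing them to the drafter, rather than concatenating the list later. A self-contained illustration of the resulting shape, with toy sizes (three auxiliary layers used as an example):

import torch

num_scheduled_tokens, hidden = 6, 4
aux_hidden_states = [torch.randn(10, hidden) for _ in range(3)]  # e.g. 3 aux layers
target_hidden_states = torch.cat(
    [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1)
assert target_hidden_states.shape == (num_scheduled_tokens, 3 * hidden)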
@@ -1506,6 +1499,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
hidden_states = outputs
if self.use_spec_decode and \
self.speculative_config.method in ('eagle', 'eagle3'):
assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
return hidden_states[logit_indices]