Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-20 02:24:34 +08:00)
[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE (#17211)
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
parent c9c1b59e59
commit 70788bdbdc
@@ -36,6 +36,10 @@ def parse_args():
         help="downloaded from the eagle repo " \
             "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
     )
+    parser.add_argument("--method",
+                        type=str,
+                        default='eagle',
+                        choices=['eagle', 'eagle3'])
     parser.add_argument("--max_num_seqs", type=int, default=8)
     parser.add_argument("--num_prompts", type=int, default=80)
     parser.add_argument("--num_spec_tokens", type=int, default=2)
@@ -53,7 +57,13 @@ def main():
     args = parse_args()

     model_dir = "meta-llama/Llama-3.1-8B-Instruct"
-    eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
+
+    if args.method == 'eagle':
+        eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
+    elif args.method == 'eagle3':
+        eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
+    else:
+        raise ValueError(f"unknown method: {args.method}")

     max_model_len = 2048

@@ -81,7 +91,7 @@ def main():
         max_num_seqs=args.max_num_seqs,
         gpu_memory_utilization=0.8,
         speculative_config={
-            "method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle",
+            "method": args.method,
             "model": eagle_dir,
             "num_speculative_tokens": args.num_spec_tokens,
             "draft_tensor_parallel_size": args.draft_tp,
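
For reference, a condensed sketch of how these pieces fit together when building the engine in the example script. The wrapper function and its defaults below are illustrative; only the checkpoint names, the speculative_config keys, and the engine arguments shown in the diff are taken from the source.

from vllm import LLM

def build_llm(method: str = "eagle", num_spec_tokens: int = 2) -> LLM:
    # Illustrative helper: choose the draft checkpoint from --method and pass
    # it through speculative_config, mirroring the example script above.
    eagle_dirs = {
        "eagle": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        "eagle3": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
    }
    if method not in eagle_dirs:
        raise ValueError(f"unknown method: {method}")
    return LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",
        max_model_len=2048,
        gpu_memory_utilization=0.8,
        speculative_config={
            "method": method,
            "model": eagle_dirs[method],
            "num_speculative_tokens": num_spec_tokens,
        },
    )
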
@@ -347,8 +347,12 @@ class VllmBackend:
         PASS_KEY = "post_grad_custom_post_pass"
         if PASS_KEY in inductor_config:
             # Config should automatically wrap all inductor passes
-            assert isinstance(inductor_config[PASS_KEY], InductorPass)
-            self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
+            if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
+                assert (inductor_config[PASS_KEY].uuid() ==
+                        self.post_grad_pass_manager.uuid())
+            else:
+                assert isinstance(inductor_config[PASS_KEY], InductorPass)
+                self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
         inductor_config[PASS_KEY] = self.post_grad_pass_manager

     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
@@ -408,8 +412,13 @@ class VllmBackend:
             )
             self.compilation_config.cache_dir = cache_dir

-        cache_dir = self.compilation_config.cache_dir
+        if compilation_counter.num_graphs_seen > 0:
+            cache_dir = self.compilation_config.cache_dir + \
+                f'-{compilation_counter.num_graphs_seen}'
+        else:
+            cache_dir = self.compilation_config.cache_dir
         os.makedirs(cache_dir, exist_ok=True)
+        self.compilation_config.cache_dir = cache_dir
         rank = vllm_config.parallel_config.rank
         dp_rank = vllm_config.parallel_config.data_parallel_rank
         local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
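
A small illustration of the cache-directory naming this hunk introduces: when more than one graph is compiled in the same process (typically the target model first, then the EAGLE head), later graphs get a numeric suffix so their Inductor artifacts do not collide. The helper below is a hypothetical stand-alone restatement of that logic, not vLLM's own function.

import os

def graph_cache_dir(base_dir: str, num_graphs_seen: int) -> str:
    # Graph 0 keeps the base directory; every later graph appends "-<index>".
    cache_dir = base_dir if num_graphs_seen == 0 else f"{base_dir}-{num_graphs_seen}"
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir

# e.g. first compiled graph -> ".../cache", second compiled graph -> ".../cache-1"
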
@@ -6,7 +6,8 @@ import torch
 import torch.nn as nn
 from transformers import LlamaConfig

-from vllm.config import ModelConfig
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -37,17 +38,19 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         self.input_layernorm = nn.Identity()


+@support_torch_compile
 class LlamaModel(nn.Module):

     def __init__(
         self,
         *,
-        model_config: ModelConfig,
-        start_layer_id: int = 0,
+        vllm_config: VllmConfig,
         prefix: str = "",
+        start_layer_id: int = 0,
     ) -> None:
         super().__init__()
-        self.config = model_config.hf_config
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             self.config.vocab_size,
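
The @support_torch_compile decorator is what opts the EAGLE draft model's forward pass into vLLM's torch.compile pipeline (and hence into piecewise CUDA graphs). Conceptually it plays a role similar to wrapping the module with torch.compile; the plain-PyTorch analogue below uses an illustrative toy module and is not vLLM's implementation.

import torch
import torch.nn as nn

class TinyDraftBlock(nn.Module):
    # Toy stand-in for the decorated LlamaModel; only the shape of the idea matters.
    def __init__(self, hidden: int = 64):
        super().__init__()
        self.fc = nn.Linear(2 * hidden, hidden)

    def forward(self, input_embeds: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        # EAGLE-style fusion: concatenate token embeddings with the target
        # model's hidden states and project back to the hidden size.
        return self.fc(torch.cat((input_embeds, hidden_states), dim=-1))

block = torch.compile(TinyDraftBlock())  # rough analogue of @support_torch_compile
out = block(torch.randn(8, 64), torch.randn(8, 64))
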
@@ -75,8 +78,7 @@ class LlamaModel(nn.Module):
         hidden_states = self.fc(
             torch.cat((input_embeds, hidden_states), dim=-1))
         residual = None
-        for i in range(len(self.layers)):
-            layer = self.layers[i]
+        for layer in self.layers:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
@@ -117,12 +119,13 @@ class LlamaModel(nn.Module):

 class EagleLlamaForCausalLM(LlamaForCausalLM):

-    def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
         nn.Module.__init__(self)
-        self.config = model_config.hf_config
-        self.model = LlamaModel(model_config=model_config,
-                                start_layer_id=start_layer_id,
-                                prefix="model")
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
+        self.model = LlamaModel(vllm_config=vllm_config,
+                                prefix="model",
+                                start_layer_id=start_layer_id)

         logit_scale = getattr(self.config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.config.vocab_size,
@@ -6,7 +6,7 @@ import torch
 import torch.nn as nn
 from transformers import LlamaConfig

-from vllm.config import ModelConfig
+from vllm.config import ModelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -167,8 +167,9 @@ class LlamaModel(nn.Module):

 class Eagle3LlamaForCausalLM(LlamaForCausalLM):

-    def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
+    def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
         nn.Module.__init__(self)
+        model_config = vllm_config.speculative_config.draft_model_config
         self.config = model_config.hf_config
         self.model = LlamaModel(model_config=model_config,
                                 start_layer_id=start_layer_id,
@@ -4,7 +4,7 @@ import torch.nn as nn
 import triton
 import triton.language as tl

-from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.loader import get_model_loader
@@ -26,10 +26,41 @@ class EagleProposer:
         device: torch.device,
     ):
         self.vllm_config = vllm_config
+        self.method = self.vllm_config.speculative_config.method
         self.num_speculative_tokens = (
             vllm_config.speculative_config.num_speculative_tokens)
         self.max_model_len = vllm_config.model_config.max_model_len
         self.block_size = vllm_config.cache_config.block_size
+
+        self.dtype = vllm_config.model_config.dtype
+
+        self.max_num_tokens = vllm_config.scheduler_config \
+            .max_num_batched_tokens
+
+        self.hidden_size = vllm_config.model_config.get_hidden_size()
+
+        # TODO: make eagle3 compatible with cudagraph
+        self.use_cuda_graph = self.method != 'eagle3' and \
+            (self.vllm_config.compilation_config.level
+             == CompilationLevel.PIECEWISE and
+             not self.vllm_config.model_config.enforce_eager)
+
+        self.cudagraph_batch_sizes = list(
+            reversed(
+                self.vllm_config.compilation_config.cudagraph_capture_sizes))
+
+        # persistent buffers for cuda graph
+        self.input_ids = torch.zeros(self.max_num_tokens,
+                                     dtype=torch.int32,
+                                     device=device)
+        self.positions = torch.zeros(self.max_num_tokens,
+                                     dtype=torch.int64,
+                                     device=device)
+
+        self.hidden_states = torch.zeros(
+            (self.max_num_tokens, self.hidden_size),
+            dtype=self.dtype,
+            device=device)
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
         self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
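
The persistent input_ids / positions / hidden_states buffers exist because CUDA graph replay requires stable input addresses: each step copies the real inputs into a fixed, pre-allocated region, and the captured graph always reads from that same memory. Below is a generic PyTorch sketch of the pattern, independent of vLLM's classes; the class and method names are illustrative only.

import torch

class PersistentInputs:
    """Generic sketch: fixed max-size buffers that a captured graph can read."""

    def __init__(self, max_num_tokens: int, hidden_size: int,
                 device: str = "cuda", dtype: torch.dtype = torch.float16):
        self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)
        self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
        self.hidden_states = torch.zeros(max_num_tokens, hidden_size,
                                         dtype=dtype, device=device)

    def stage(self, token_ids: torch.Tensor, positions: torch.Tensor,
              hidden_states: torch.Tensor) -> int:
        # Copy real inputs into the prefix of the persistent buffers; the graph
        # is then replayed over a padded slice whose addresses never change.
        n = token_ids.shape[0]
        self.input_ids[:n] = token_ids
        self.positions[:n] = positions
        self.hidden_states[:n] = hidden_states
        return n
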
@@ -59,13 +90,12 @@ class EagleProposer:
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1

-        input_ids = torch.empty_like(target_token_ids)
         # Shift the input ids by one token.
         # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
-        input_ids[:-1] = target_token_ids[1:]
+        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
         # Replace the last token with the next token.
         # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
-        input_ids[last_token_indices] = next_token_ids
+        self.input_ids[last_token_indices] = next_token_ids

         # FA requires seq_len to have dtype int32.
         seq_lens = (target_positions[last_token_indices] + 1).int()
@@ -88,14 +118,30 @@ class EagleProposer:
             prefix_kv_lens=None,
             suffix_kv_lens=None,
         )
+        if self.use_cuda_graph and \
+                num_tokens <= self.cudagraph_batch_sizes[-1]:
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+        else:
+            num_input_tokens = num_tokens
+        # copy inputs to buffer for cudagraph
+        self.positions[:num_tokens] = target_positions

-        with set_forward_context(attn_metadata, self.vllm_config):
-            hidden_states_logits, hidden_states_fwd = self.model(
-                input_ids=input_ids,
-                hidden_states=target_hidden_states,
-                positions=target_positions,
+        if self.method == 'eagle':
+            self.hidden_states[:num_tokens] = target_hidden_states
+            hidden_states = self.hidden_states
+        else:
+            # TODO: make eagle3 compatible with cuda graph
+            hidden_states = target_hidden_states
+
+        with set_forward_context(attn_metadata,
+                                 self.vllm_config,
+                                 num_tokens=num_input_tokens):
+            last_hidden_states, hidden_states = self.model(
+                input_ids=self.input_ids[:num_input_tokens],
+                positions=self.positions[:num_input_tokens],
+                hidden_states=hidden_states[:num_input_tokens],
             )
-        sample_hidden_states = hidden_states_logits[last_token_indices]
+        sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
         draft_token_ids = logits.argmax(dim=-1)

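
pad_for_cudagraph rounds the actual token count up to a batch size for which a CUDA graph was captured, so one captured graph can serve a range of real batch sizes; the guard against the last entry of cudagraph_batch_sizes keeps requests larger than the biggest captured size on the eager path. A hypothetical stand-alone version of that rounding (not vLLM's implementation) is sketched below, assuming the capture sizes are sorted ascending, as they are after the reversed(...) in the hunk above.

import bisect

def pad_for_cudagraph(num_tokens: int, capture_sizes: list[int]) -> int:
    # Round num_tokens up to the smallest captured size that can hold it.
    idx = bisect.bisect_left(capture_sizes, num_tokens)
    if idx == len(capture_sizes):
        # Larger than every captured graph: caller falls back to eager mode.
        return num_tokens
    return capture_sizes[idx]

assert pad_for_cudagraph(3, [1, 2, 4, 8]) == 4
assert pad_for_cudagraph(8, [1, 2, 4, 8]) == 8
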
@@ -108,13 +154,20 @@ class EagleProposer:
         draft_token_ids_list = [draft_token_ids]

         positions = target_positions[last_token_indices]
-        hidden_states = hidden_states_fwd[last_token_indices]
+        hidden_states = hidden_states[last_token_indices]
+        if self.use_cuda_graph and \
+                batch_size <= self.cudagraph_batch_sizes[-1]:
+            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
+        else:
+            input_batch_size = batch_size
         attn_metadata.num_actual_tokens = batch_size
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
         for _ in range(self.num_speculative_tokens - 1):
             # Update the inputs.
-            input_ids = draft_token_ids_list[-1]
+            # cast to int32 is crucial when eagle model is compiled.
+            # tensor.argmax() returns int64 by default.
+            input_ids = draft_token_ids_list[-1].int()
             positions += 1

             # NOTE(woosuk): We should handle the case where the draft model
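
The .int() cast matters because torch.argmax returns int64 while the persistent input_ids buffer that the compiled graph was captured on is int32. A quick check of the default dtype (values here are made up):

import torch

logits = torch.randn(4, 128)
draft_ids = logits.argmax(dim=-1)
print(draft_ids.dtype)        # torch.int64: argmax's default output dtype
print(draft_ids.int().dtype)  # torch.int32: matches the persistent input_ids buffer
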
@@ -152,14 +205,27 @@ class EagleProposer:
             attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                     PADDING_SLOT_ID)

+            # copy inputs to buffer for cudagraph
+            self.input_ids[:batch_size] = input_ids
+            self.positions[:batch_size] = clamped_positions
+
+            if self.method == 'eagle':
+                # TODO: make eagle3 compatible with cudagraph.
+                self.hidden_states[:batch_size] = hidden_states
+                hidden_states = self.hidden_states
+
             # Run the model.
-            with set_forward_context(attn_metadata, self.vllm_config):
-                hidden_states_logits, hidden_states = self.model(
-                    input_ids=input_ids,
-                    hidden_states=hidden_states,
-                    positions=clamped_positions,
+            with set_forward_context(attn_metadata,
+                                     self.vllm_config,
+                                     num_tokens=input_batch_size):
+                last_hidden_states, hidden_states = self.model(
+                    input_ids=self.input_ids[:input_batch_size],
+                    positions=self.positions[:input_batch_size],
+                    hidden_states=hidden_states[:input_batch_size],
                 )
-            logits = self.model.compute_logits(hidden_states_logits, None)
+            hidden_states = hidden_states[:batch_size]
+            logits = self.model.compute_logits(last_hidden_states[:batch_size],
+                                               None)
             draft_token_ids = logits.argmax(dim=-1)
             draft_token_ids_list.append(draft_token_ids)

@@ -227,13 +293,11 @@ class EagleProposer:
         draft_model_cls, arch = ModelRegistry.resolve_model_cls(
             draft_model_config.architectures)
         self.model = draft_model_cls(
-            model_config=draft_model_config,
+            vllm_config=self.vllm_config,
             start_layer_id=target_layer_num).to(target_device)

         loaded_weights = self.model.load_weights(
-            loader.get_all_weights(
-                self.vllm_config.speculative_config.draft_model_config,
-                self.model))
+            loader.get_all_weights(draft_model_config, self.model))
         if self.vllm_config.speculative_config.method == "eagle3":
             if "model.embed_tokens.weight" not in loaded_weights:
                 logger.info(
@@ -243,6 +307,20 @@ class EagleProposer:
             logger.info("Loading EAGLE LM head weights from the target model.")
             self.model.lm_head = target_model.lm_head

+    @torch.inference_mode()
+    def dummy_run(
+        self,
+        num_tokens: int,
+    ) -> None:
+        with set_forward_context(None, self.vllm_config,
+                                 num_tokens=num_tokens):
+            if self.method == 'eagle':
+                self.model(
+                    input_ids=self.input_ids[:num_tokens],
+                    positions=self.positions[:num_tokens],
+                    hidden_states=self.hidden_states[:num_tokens],
+                )
+

 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
 # to sample the draft tokens. We will use this after we find a way to manage
@@ -1106,7 +1106,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # For mid-pipeline stages, return the hidden states.
             return hidden_states

-        hidden_states = hidden_states[:num_scheduled_tokens]
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)

@@ -1172,7 +1171,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):

         # Compute prompt logprobs if needed.
         prompt_logprobs_dict = self._get_prompt_logprobs_dict(
-            hidden_states,
+            hidden_states[:num_scheduled_tokens],
             scheduler_output,
         )

@@ -1222,15 +1221,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):

         if spec_decode_metadata is None:
             # input_ids can be None for multimodal models.
-            # We need to slice token_ids, positions, and hidden_states
-            # because the eagle head does not use cuda graph and should
-            # not include padding.
             target_token_ids = self.input_ids[:num_scheduled_tokens]
             target_positions = positions[:num_scheduled_tokens]
             if self.use_aux_hidden_state_outputs:
-                target_hidden_states = [
-                    h[:num_scheduled_tokens] for h in aux_hidden_states
-                ]
+                target_hidden_states = torch.cat(
+                    [h[:num_scheduled_tokens] for h in aux_hidden_states],
+                    dim=-1)
             else:
                 target_hidden_states = hidden_states[:num_scheduled_tokens]
             target_slot_mapping = attn_metadata.slot_mapping
@@ -1254,15 +1250,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             target_token_ids = self.input_ids[token_indices]
             target_positions = positions[token_indices]
             if self.use_aux_hidden_state_outputs:
-                target_hidden_states = [
-                    h[token_indices] for h in aux_hidden_states
-                ]
+                target_hidden_states = torch.cat(
+                    [h[token_indices] for h in aux_hidden_states], dim=-1)
             else:
                 target_hidden_states = hidden_states[token_indices]
             target_slot_mapping = attn_metadata.slot_mapping[token_indices]

-        if self.use_aux_hidden_state_outputs:
-            target_hidden_states = torch.cat(target_hidden_states, dim=-1)
         draft_token_ids = self.drafter.propose(
             target_token_ids=target_token_ids,
             target_positions=target_positions,
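
For EAGLE-3, the runner now concatenates the auxiliary hidden states from several target-model layers along the feature dimension before handing them to the drafter (previously the list was concatenated later, after slicing). A small shape illustration with made-up dimensions:

import torch

num_tokens, hidden_size = 6, 8
# Hypothetical auxiliary hidden states taken from three target-model layers.
aux_hidden_states = [torch.randn(num_tokens, hidden_size) for _ in range(3)]
token_indices = torch.tensor([0, 2, 5])

target_hidden_states = torch.cat(
    [h[token_indices] for h in aux_hidden_states], dim=-1)
print(target_hidden_states.shape)  # torch.Size([3, 24]): layers stacked along dim=-1
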
@@ -1506,6 +1499,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             else:
                 hidden_states = outputs

+            if self.use_spec_decode and \
+                    self.speculative_config.method in ('eagle', 'eagle3'):
+                assert isinstance(self.drafter, EagleProposer)
+                self.drafter.dummy_run(num_tokens)
+
             logit_indices = np.cumsum(num_scheduled_tokens) - 1
             return hidden_states[logit_indices]
