[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE3 (#17504)

Signed-off-by: qizixi <qizixi@meta.com>

parent 9b70e2b4c1
commit 39c0813a7f
vllm/model_executor/models/llama_eagle3.py

@@ -6,7 +6,8 @@ import torch
 import torch.nn as nn
 from transformers import LlamaConfig
 
-from vllm.config import ModelConfig, VllmConfig
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -76,17 +77,19 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         return hidden_states, residual
 
 
+@support_torch_compile
 class LlamaModel(nn.Module):
 
     def __init__(
         self,
         *,
-        model_config: ModelConfig,
+        vllm_config: VllmConfig,
         start_layer_id: int = 0,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.config = model_config.hf_config
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             self.config.vocab_size,
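The decorator change above is the heart of the torch.compile side of this commit: `support_torch_compile` needs the module to be built from the full `VllmConfig` so it can read the compilation settings. Below is a self-contained toy sketch of that decorator pattern, not vllm's actual implementation (which consults `vllm_config.compilation_config` and integrates with piecewise cudagraphs); `toy_support_torch_compile` and `TinyDraftModel` are illustrative names only.

import torch
import torch.nn as nn


def toy_support_torch_compile(cls):
    """Wrap forward() in torch.compile when the module is constructed."""
    orig_init = cls.__init__

    def wrapped_init(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        # The real decorator consults vllm_config.compilation_config here;
        # this toy version compiles unconditionally.
        self.forward = torch.compile(self.forward, dynamic=True)

    cls.__init__ = wrapped_init
    return cls


@toy_support_torch_compile
class TinyDraftModel(nn.Module):

    def __init__(self, *, hidden_size: int = 8) -> None:
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


model = TinyDraftModel(hidden_size=8)
print(model(torch.randn(4, 8)).shape)  # torch.Size([4, 8])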
@@ -119,8 +122,7 @@ class LlamaModel(nn.Module):
         hidden_states: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_embeds = self.embed_tokens(input_ids)
-        if (hidden_states.shape[-1] != input_embeds.shape[-1]):
-            hidden_states = self.fc(hidden_states)
+        assert hidden_states.shape[-1] == input_embeds.shape[-1]
 
         residual = None
         hidden_states, residual = self.layers[0](
@@ -169,9 +171,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
 
     def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
         nn.Module.__init__(self)
-        model_config = vllm_config.speculative_config.draft_model_config
-        self.config = model_config.hf_config
-        self.model = LlamaModel(model_config=model_config,
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
+        self.model = LlamaModel(vllm_config=vllm_config,
                                 start_layer_id=start_layer_id,
                                 prefix="model")
 
@@ -214,6 +216,13 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
         logits_new[:, targets] = logits
         return logits_new
 
+    def combine_hidden_states(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        # combine multiple auxiliary hidden states returned by eagle3
+        return self.model.fc(hidden_states)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(
             self,
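The new `combine_hidden_states` gives the proposer a single entry point for EAGLE3's fc projection. A minimal sketch of what it computes, assuming the usual EAGLE3 setup where the target model returns auxiliary hidden states from several layers concatenated along the last dimension; `num_aux_layers = 3` and the randomly initialized `fc` are stand-ins for the projection loaded from the draft checkpoint:

import torch
import torch.nn as nn

hidden_size = 16
num_aux_layers = 3  # assumption: states tapped from three target layers

# stand-in for self.model.fc; real weights come from the draft checkpoint
fc = nn.Linear(num_aux_layers * hidden_size, hidden_size, bias=False)

# [num_tokens, num_aux_layers * hidden_size] auxiliary states from the target
aux_hidden_states = torch.randn(5, num_aux_layers * hidden_size)

combined = fc(aux_hidden_states)
print(combined.shape)  # torch.Size([5, 16]) -- now matches self.hidden_size

This is also why the proposer can assert `target_hidden_states.shape[-1] == self.hidden_size` right after combining (see the hunk at line 90 below).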
vllm/v1/spec_decode/eagle.py

@@ -10,6 +10,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models import ModelRegistry
+from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 
@@ -39,11 +40,9 @@ class EagleProposer:
 
         self.hidden_size = vllm_config.model_config.get_hidden_size()
 
-        # TODO: make eagle3 compatible with cudagraph
-        self.use_cuda_graph = self.method != 'eagle3' and \
-            (self.vllm_config.compilation_config.level
-             == CompilationLevel.PIECEWISE and
-             not self.vllm_config.model_config.enforce_eager)
+        self.use_cuda_graph = (self.vllm_config.compilation_config.level
+                               == CompilationLevel.PIECEWISE and
+                               not self.vllm_config.model_config.enforce_eager)
 
         self.cudagraph_batch_sizes = list(
             reversed(
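With the `self.method != 'eagle3'` clause dropped, EAGLE3 now takes the same cudagraph gate as EAGLE. Restated as a standalone predicate (the `CompilationLevel` enum below is a stand-in for `vllm.config.CompilationLevel`):

from enum import IntEnum


class CompilationLevel(IntEnum):  # stand-in for vllm.config.CompilationLevel
    NO_COMPILATION = 0
    DYNAMO_AS_IS = 1
    DYNAMO_ONCE = 2
    PIECEWISE = 3


def use_cuda_graph(level: CompilationLevel, enforce_eager: bool) -> bool:
    # no `self.method != 'eagle3'` clause any more: EAGLE3 is included
    return level == CompilationLevel.PIECEWISE and not enforce_eager


print(use_cuda_graph(CompilationLevel.PIECEWISE, False))  # True
print(use_cuda_graph(CompilationLevel.PIECEWISE, True))   # False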
@@ -90,6 +89,12 @@ class EagleProposer:
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
 
+        if self.method == "eagle3":
+            assert isinstance(self.model, Eagle3LlamaForCausalLM)
+            target_hidden_states = self.model.combine_hidden_states(
+                target_hidden_states)
+            assert target_hidden_states.shape[-1] == self.hidden_size
+
         # Shift the input ids by one token.
         # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
         self.input_ids[:num_tokens - 1] = target_token_ids[1:]
@@ -126,12 +131,7 @@ class EagleProposer:
         # copy inputs to buffer for cudagraph
         self.positions[:num_tokens] = target_positions
 
-        if self.method == 'eagle':
-            self.hidden_states[:num_tokens] = target_hidden_states
-            hidden_states = self.hidden_states
-        else:
-            # TODO: make eagle3 compatible with cuda graph
-            hidden_states = target_hidden_states
+        self.hidden_states[:num_tokens] = target_hidden_states
 
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
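This hunk (and the two that follow) converge EAGLE and EAGLE3 on the persistent-buffer pattern that cudagraphs require: a captured graph replays on fixed memory addresses, so per-step tensors are copied into preallocated buffers and the model only ever reads buffer slices. A minimal sketch of the pattern (the buffer sizes and the toy `run_step` are illustrative, not vllm code):

import torch

max_num_tokens, hidden_size = 8, 4
hidden_states_buf = torch.zeros(max_num_tokens, hidden_size)  # persistent


def run_step(target_hidden_states: torch.Tensor) -> torch.Tensor:
    num_tokens = target_hidden_states.shape[0]
    # copy inputs to buffer for cudagraph (mirrors the diff)
    hidden_states_buf[:num_tokens] = target_hidden_states
    # the model consumes the buffer slice, never the fresh tensor
    return hidden_states_buf[:num_tokens] * 2.0


print(run_step(torch.ones(3, hidden_size)).shape)  # torch.Size([3, 4])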
@@ -139,7 +139,7 @@ class EagleProposer:
             last_hidden_states, hidden_states = self.model(
                 input_ids=self.input_ids[:num_input_tokens],
                 positions=self.positions[:num_input_tokens],
-                hidden_states=hidden_states[:num_input_tokens],
+                hidden_states=self.hidden_states[:num_input_tokens],
             )
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
@@ -209,10 +209,7 @@ class EagleProposer:
             self.input_ids[:batch_size] = input_ids
             self.positions[:batch_size] = clamped_positions
 
-            if self.method == 'eagle':
-                # TODO: make eagle3 compatible with cudagraph.
-                self.hidden_states[:batch_size] = hidden_states
-                hidden_states = self.hidden_states
+            self.hidden_states[:batch_size] = hidden_states
 
             # Run the model.
             with set_forward_context(attn_metadata,
@@ -221,7 +218,7 @@ class EagleProposer:
                 last_hidden_states, hidden_states = self.model(
                     input_ids=self.input_ids[:input_batch_size],
                     positions=self.positions[:input_batch_size],
-                    hidden_states=hidden_states[:input_batch_size],
+                    hidden_states=self.hidden_states[:input_batch_size],
                 )
             hidden_states = hidden_states[:batch_size]
             logits = self.model.compute_logits(last_hidden_states[:batch_size],
@@ -314,12 +311,11 @@ class EagleProposer:
     ) -> None:
         with set_forward_context(None, self.vllm_config,
                                  num_tokens=num_tokens):
-            if self.method == 'eagle':
-                self.model(
-                    input_ids=self.input_ids[:num_tokens],
-                    positions=self.positions[:num_tokens],
-                    hidden_states=self.hidden_states[:num_tokens],
-                )
+            self.model(
+                input_ids=self.input_ids[:num_tokens],
+                positions=self.positions[:num_tokens],
+                hidden_states=self.hidden_states[:num_tokens],
+            )
 
 
 # NOTE(woosuk): Currently, the below code is not used and we always use argmax
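Finally, the dummy run above now exercises the draft model unconditionally, which is what lets cudagraph capture warm up EAGLE3 as well as EAGLE. A toy sketch of the warmup idea, assuming (as in vllm's model runner) that graphs are captured once per padded size in `cudagraph_batch_sizes`:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
input_buf = torch.zeros(8, 4)  # persistent buffer sized to the max batch
cudagraph_batch_sizes = [8, 4, 2, 1]  # assumption: capture sizes, descending

for num_tokens in cudagraph_batch_sizes:
    # dummy run: exercise the exact buffer slice the captured graph replays on
    model(input_buf[:num_tokens])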