Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Spec-Decode] Support piecewise cudagraphs for Eagle head (#25109)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
This commit is contained in:
parent da4455609d
commit 29255cfc3b
@@ -50,11 +50,14 @@ class CUDAGraphMode(enum.Enum):
     def mixed_mode(self) -> "CUDAGraphMode":
         return CUDAGraphMode(self.value[1]) if self.separate_routine() else self
 
+    def has_mode(self, mode: "CUDAGraphMode") -> bool:
+        assert not mode.separate_routine()
+        if self.separate_routine():
+            return mode.value in self.value
+        return self == mode
+
     def requires_piecewise_compilation(self) -> bool:
-        return (
-            self.decode_mode() == CUDAGraphMode.PIECEWISE
-            or self.mixed_mode() == CUDAGraphMode.PIECEWISE
-        )
+        return self.has_mode(CUDAGraphMode.PIECEWISE)
 
     def max_cudagraph_mode(self) -> "CUDAGraphMode":
         return CUDAGraphMode(max(self.value)) if self.separate_routine() else self
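Aside: a minimal, runnable sketch of the enum semantics has_mode relies on. The member values below are assumptions modeled on vLLM's CUDAGraphMode (combined modes carry a (decode, mixed) tuple, single modes a plain int); only the has_mode logic is taken from the hunk above.

import enum

class CUDAGraphModeSketch(enum.Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    FULL_DECODE_ONLY = (2, 0)    # assumed: full graph for decode, eager mixed
    FULL_AND_PIECEWISE = (2, 1)  # assumed: full for decode, piecewise for mixed

    def separate_routine(self) -> bool:
        # Combined modes use different routines for decode vs. mixed batches.
        return isinstance(self.value, tuple)

    def has_mode(self, mode: "CUDAGraphModeSketch") -> bool:
        # `mode` must be a single mode; combined modes check tuple membership.
        assert not mode.separate_routine()
        if self.separate_routine():
            return mode.value in self.value
        return self == mode

assert CUDAGraphModeSketch.FULL_AND_PIECEWISE.has_mode(CUDAGraphModeSketch.PIECEWISE)
assert not CUDAGraphModeSketch.FULL_DECODE_ONLY.has_mode(CUDAGraphModeSketch.PIECEWISE)

This is why requires_piecewise_compilation() reduces to has_mode(PIECEWISE): piecewise compilation is needed exactly when some routine runs in PIECEWISE mode.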
@@ -283,6 +283,12 @@ def set_forward_context(
         vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
     )
 
+    # Convenience: if cudagraphs are used and num_tokens is given, we can
+    # create a batch descriptor here if one wasn't provided (there's no harm:
+    # if it doesn't match in the wrapper, it falls through).
+    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
+        batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)
+
     forward_context = create_forward_context(
         attn_metadata,
         vllm_config,
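Aside: why pre-creating the descriptor is described as harmless. A dispatcher keyed by BatchDescriptor can simply run eagerly when the descriptor was never captured. The dispatcher below is hypothetical, and BatchDescriptor's fields are assumptions modeled on vLLM's dataclass.

from dataclasses import dataclass
from typing import Callable, Optional

@dataclass(frozen=True)
class BatchDescriptor:
    num_tokens: int
    uniform_decode: bool = False  # assumed field

class DispatcherSketch:
    def __init__(self) -> None:
        # Maps captured descriptors to graph-replay callables.
        self.captured: dict[BatchDescriptor, Callable[[], str]] = {}

    def run(self, desc: Optional[BatchDescriptor], eager: Callable[[], str]) -> str:
        graph = self.captured.get(desc) if desc is not None else None
        # Unmatched descriptor: fall through to eager execution, no harm done.
        return graph() if graph is not None else eager()

dispatcher = DispatcherSketch()
assert dispatcher.run(BatchDescriptor(num_tokens=7), eager=lambda: "eager") == "eager"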
@@ -7,6 +7,7 @@ import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
 
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -162,6 +163,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
         return logits
 
 
+@support_torch_compile
 class DeepSeekMTP(nn.Module, SupportsPP):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -8,6 +8,7 @@ import torch
 import torch.nn as nn
 from transformers import LlamaConfig
 
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -21,6 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import NestedTensors
 
 from .utils import AutoWeightsLoader, maybe_prefix
@@ -119,6 +121,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         return hidden_states, residual
 
 
+@support_torch_compile(
+    # torch.compile is disabled for multimodal EAGLE3 models due to constraint
+    # violations with dynamic shapes during tensor concatenation operations.
+    # See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132
+    # Non-multimodal EAGLE3 models can still use torch.compile safely.
+    enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs(
+        vllm_config.model_config
+    ),
+)
 class LlamaModel(nn.Module):
     def __init__(
         self,
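Aside: the enable_if pattern, sketched. This assumes, as the hunk suggests, that support_torch_compile evaluates the predicate against the vllm_config and leaves the class uncompiled when it returns False; the decorator below is an illustrative stand-in, not vLLM's implementation.

from typing import Callable, Optional

def compile_if(enable_if: Optional[Callable[[object], bool]] = None):
    def decorate(cls):
        def should_compile(vllm_config) -> bool:
            # Compile unconditionally when no predicate is given.
            return enable_if is None or enable_if(vllm_config)
        cls.should_compile = staticmethod(should_compile)
        return cls
    return decorate

class FakeConfig:
    supports_mm = True  # stands in for MULTIMODAL_REGISTRY.supports_multimodal_inputs(...)

@compile_if(enable_if=lambda cfg: not cfg.supports_mm)
class DraftModelSketch:
    pass

assert DraftModelSketch.should_compile(FakeConfig()) is False  # multimodal: stays eager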
@@ -9,7 +9,12 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from vllm.config import CompilationLevel, VllmConfig, get_layers_from_vllm_config
+from vllm.config import (
+    CompilationLevel,
+    CUDAGraphMode,
+    VllmConfig,
+    get_layers_from_vllm_config,
+)
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
@@ -80,12 +85,25 @@ class EagleProposer:
         self.attn_layer_names: list[str] = []
         self.indexer_layer_names: list[str] = []
 
-        self.use_cuda_graph = (
-            not current_platform.is_xpu()
-            and self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
-            and not self.vllm_config.model_config.enforce_eager
-            and not self.speculative_config.enforce_eager
-        )
+        self.use_cuda_graph = False
+
+        compilation_config = self.vllm_config.compilation_config
+        if compilation_config.level == CompilationLevel.PIECEWISE:
+            cudagraph_mode = compilation_config.cudagraph_mode
+            if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
+                CUDAGraphMode.PIECEWISE
+            ):
+                logger.warning(
+                    "Currently the eagle proposer only supports cudagraph_mode "
+                    "PIECEWISE; if you want the drafter to use cuda graphs, "
+                    "please set compilation_config.cudagraph_mode to PIECEWISE "
+                    "or FULL_AND_PIECEWISE."
+                )
+            self.use_cuda_graph = (
+                cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
+                and not self.speculative_config.enforce_eager
+            )
+
         self.cudagraph_batch_sizes = (
             list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes))
             if self.use_cuda_graph
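Aside: the new gating condition, restated as a pure function. The drafter uses piecewise cudagraphs only under PIECEWISE compilation, when the configured mode includes PIECEWISE, and when the drafter itself is not forced eager; the old XPU and target-model enforce_eager checks are gone, the latter presumably subsumed by the level check since enforce_eager disables piecewise compilation.

def drafter_uses_cudagraphs(
    level_is_piecewise: bool,
    mode_has_piecewise: bool,
    drafter_enforce_eager: bool,
) -> bool:
    # use_cuda_graph starts False and is only flipped under these conditions.
    return level_is_piecewise and mode_has_piecewise and not drafter_enforce_eager

assert drafter_uses_cudagraphs(True, True, False)
assert not drafter_uses_cudagraphs(True, False, False)   # e.g. a FULL-only mode
assert not drafter_uses_cudagraphs(False, True, False)   # no piecewise compilation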
@@ -239,12 +257,15 @@ class EagleProposer:
         per_layer_attn_metadata = {}
         for layer_name in self.attn_layer_names:
             per_layer_attn_metadata[layer_name] = attn_metadata
 
         for layer_name in self.indexer_layer_names:
             assert draft_indexer_metadata is not None
             per_layer_attn_metadata[layer_name] = draft_indexer_metadata
+
+        cudagraph_runtime_mode = CUDAGraphMode.NONE
         if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
             num_input_tokens = num_tokens
         # copy inputs to buffer for cudagraph
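Aside: how padding to a captured size behaves, under the assumption that pad_for_cudagraph returns the smallest captured size >= num_tokens. cudagraph_batch_sizes is ascending here (it is built as list(reversed(cudagraph_capture_sizes)) above), so comparing against its last element admits exactly the paddable range.

import bisect

def pad_for_cudagraph_sketch(num_tokens: int, capture_sizes: list[int]) -> int:
    # Smallest captured size that can hold num_tokens (capture_sizes ascending).
    return capture_sizes[bisect.bisect_left(capture_sizes, num_tokens)]

sizes = [1, 2, 4, 8, 16]
assert pad_for_cudagraph_sketch(3, sizes) == 4     # rounded up to a graph size
assert pad_for_cudagraph_sketch(16, sizes) == 16   # exact hit, no padding
# num_tokens > sizes[-1] never reaches the pad: the guard above keeps
# cudagraph_runtime_mode at NONE and the batch runs eagerly.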
@@ -267,7 +288,10 @@ class EagleProposer:
             inputs_embeds = None
 
         with set_forward_context(
-            per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
+            per_layer_attn_metadata,
+            self.vllm_config,
+            num_tokens=num_input_tokens,
+            cudagraph_runtime_mode=cudagraph_runtime_mode,
         ):
             ret_hidden_states = self.model(
                 input_ids=input_ids,
@@ -326,8 +350,10 @@ class EagleProposer:
 
         if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
             input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
+            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
             input_batch_size = batch_size
+            cudagraph_runtime_mode = CUDAGraphMode.NONE
 
         common_attn_metadata.num_actual_tokens = batch_size
         common_attn_metadata.max_query_len = 1
@@ -424,7 +450,10 @@ class EagleProposer:
 
         # Run the model.
         with set_forward_context(
-            per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size
+            per_layer_attn_metadata,
+            self.vllm_config,
+            num_tokens=input_batch_size,
+            cudagraph_runtime_mode=cudagraph_runtime_mode,
         ):
             ret_hidden_states = self.model(
                 input_ids=input_ids,
@@ -731,11 +760,16 @@ class EagleProposer:
 
         if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
             num_input_tokens = num_tokens
+            cudagraph_runtime_mode = CUDAGraphMode.NONE
         # Run the model.
         with set_forward_context(
-            per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
+            per_layer_attn_metadata,
+            self.vllm_config,
+            num_tokens=num_input_tokens,
+            cudagraph_runtime_mode=cudagraph_runtime_mode,
         ):
             last_hidden_states, hidden_states = self.model(
                 input_ids=self.input_ids[:num_input_tokens],
@@ -1015,8 +1049,19 @@ class EagleProposer:
     def dummy_run(
         self,
         num_tokens: int,
+        use_cudagraphs=True,
     ) -> None:
-        with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
+        if use_cudagraphs and num_tokens <= self.cudagraph_batch_sizes[-1]:
+            num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+
+        with set_forward_context(
+            None,
+            self.vllm_config,
+            num_tokens=num_tokens,
+            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
+            if use_cudagraphs
+            else CUDAGraphMode.NONE,
+        ):
             if self.supports_mm_inputs:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
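Aside: dummy_run's new contract, restated as a tiny helper. Padding happens only when the token count fits the largest captured size, while the runtime mode depends solely on use_cudagraphs. Names are illustrative; the next-size scan stands in for vllm_config.pad_for_cudagraph.

def dummy_run_plan(num_tokens: int, use_cudagraphs: bool, capture_sizes: list[int]):
    if use_cudagraphs and num_tokens <= capture_sizes[-1]:
        # stand-in for self.vllm_config.pad_for_cudagraph(num_tokens)
        num_tokens = next(s for s in capture_sizes if s >= num_tokens)
    mode = "PIECEWISE" if use_cudagraphs else "NONE"
    return num_tokens, mode

assert dummy_run_plan(5, True, [1, 2, 4, 8]) == (8, "PIECEWISE")
assert dummy_run_plan(5, False, [1, 2, 4, 8]) == (5, "NONE")

Padding before entering the forward context matters: capture must see the same padded shapes that replay will later use.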
@@ -3441,7 +3441,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
-            self.drafter.dummy_run(num_tokens)
+            use_cudagraphs = cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
+            self.drafter.dummy_run(num_tokens, use_cudagraphs=use_cudagraphs)
 
         # This is necessary to avoid blocking DP.
         # For dummy runs, we typically skip EPLB since we don't have any real
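Aside: the runner-side handoff in miniature. The runner asks the drafter to capture only when its own dummy pass runs in PIECEWISE mode; the stub below is self-contained and purely illustrative.

import enum

class ModeSketch(enum.Enum):
    NONE = 0
    PIECEWISE = 1

class DrafterStub:
    def dummy_run(self, num_tokens: int, use_cudagraphs: bool = True) -> str:
        return "capture" if use_cudagraphs else "eager"

cudagraph_runtime_mode = ModeSketch.PIECEWISE
use_cudagraphs = cudagraph_runtime_mode == ModeSketch.PIECEWISE
assert DrafterStub().dummy_run(8, use_cudagraphs=use_cudagraphs) == "capture"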