[Spec-Decode] Support piecewise cudagraphs for Eagle head (#25109)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
parent da4455609d
commit 29255cfc3b
@@ -50,11 +50,14 @@ class CUDAGraphMode(enum.Enum):
     def mixed_mode(self) -> "CUDAGraphMode":
         return CUDAGraphMode(self.value[1]) if self.separate_routine() else self
 
+    def has_mode(self, mode: "CUDAGraphMode") -> bool:
+        assert not mode.separate_routine()
+        if self.separate_routine():
+            return mode.value in self.value
+        return self == mode
+
     def requires_piecewise_compilation(self) -> bool:
-        return (
-            self.decode_mode() == CUDAGraphMode.PIECEWISE
-            or self.mixed_mode() == CUDAGraphMode.PIECEWISE
-        )
+        return self.has_mode(CUDAGraphMode.PIECEWISE)
 
     def max_cudagraph_mode(self) -> "CUDAGraphMode":
         return CUDAGraphMode(max(self.value)) if self.separate_routine() else self
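Review note: the new has_mode() is the key primitive of this change. Below is a minimal, self-contained sketch of its semantics, assuming the hybrid members carry a (decode mode, mixed-batch mode) value pair; the concrete member values are an assumption for illustration, not part of this diff.

import enum

class CUDAGraphModeSketch(enum.Enum):
    # Assumed values: hybrid modes hold a (decode, mixed) pair.
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    FULL_DECODE_ONLY = (2, 0)    # FULL for decode, NONE for mixed batches
    FULL_AND_PIECEWISE = (2, 1)  # FULL for decode, PIECEWISE for mixed batches

    def separate_routine(self) -> bool:
        return isinstance(self.value, tuple)

    def has_mode(self, mode: "CUDAGraphModeSketch") -> bool:
        assert not mode.separate_routine()
        if self.separate_routine():
            return mode.value in self.value
        return self == mode

    def requires_piecewise_compilation(self) -> bool:
        return self.has_mode(CUDAGraphModeSketch.PIECEWISE)

# FULL_AND_PIECEWISE includes PIECEWISE for mixed batches; FULL_DECODE_ONLY does not.
assert CUDAGraphModeSketch.FULL_AND_PIECEWISE.has_mode(CUDAGraphModeSketch.PIECEWISE)
assert not CUDAGraphModeSketch.FULL_DECODE_ONLY.has_mode(CUDAGraphModeSketch.PIECEWISE)
assert CUDAGraphModeSketch.PIECEWISE.requires_piecewise_compilation()

This also shows why the old requires_piecewise_compilation() body, which checked decode_mode() and mixed_mode() separately, collapses into a single has_mode(PIECEWISE) call.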
@@ -283,6 +283,12 @@ def set_forward_context(
             vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
         )
 
+    # Convenience: if cudagraph is used and num_tokens is given, we can just
+    # create a batch descriptor here if not given (there's no harm since if it
+    # doesn't match in the wrapper it'll fall through).
+    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
+        batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)
+
     forward_context = create_forward_context(
         attn_metadata,
         vllm_config,
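Review note: with this default, callers of set_forward_context no longer have to construct a BatchDescriptor themselves when replaying cudagraphs. A stand-in sketch of the idiom (BatchDescriptor here is a hypothetical minimal dataclass, not vllm's type):

from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class BatchDescriptor:  # hypothetical stand-in
    num_tokens: int

def resolve_descriptor(
    descriptor: Optional[BatchDescriptor],
    num_tokens: Optional[int],
    cudagraphs_active: bool,
) -> Optional[BatchDescriptor]:
    # Mirrors the diff: an explicit descriptor wins; otherwise one is
    # synthesized, and a mismatch in the wrapper later simply falls through.
    if cudagraphs_active and num_tokens is not None:
        return descriptor or BatchDescriptor(num_tokens=num_tokens)
    return descriptor

assert resolve_descriptor(None, 256, True) == BatchDescriptor(num_tokens=256)
assert resolve_descriptor(None, 256, False) is None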
@@ -7,6 +7,7 @@ import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
 
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -162,6 +163,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
         return logits
 
 
+@support_torch_compile
 class DeepSeekMTP(nn.Module, SupportsPP):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -8,6 +8,7 @@ import torch
 import torch.nn as nn
 from transformers import LlamaConfig
 
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -21,6 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import NestedTensors
 
 from .utils import AutoWeightsLoader, maybe_prefix
@@ -119,6 +121,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         return hidden_states, residual
 
 
+@support_torch_compile(
+    # torch.compile is disabled for multimodal EAGLE3 models due to constraint
+    # violations with dynamic shapes during tensor concatenation operations.
+    # See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132
+    # Non-multimodal EAGLE3 models can still use torch.compile safely.
+    enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs(
+        vllm_config.model_config
+    ),
+)
 class LlamaModel(nn.Module):
     def __init__(
         self,
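Review note: the interesting piece here is the enable_if predicate, which lets a single decorated class compile conditionally on the config. A simplified sketch of the pattern (compile_if, DraftModel, and the compile_enabled flag are illustrative stand-ins; vllm's support_torch_compile installs real torch.compile machinery):

from typing import Any, Callable, Optional

def compile_if(enable_if: Optional[Callable[[Any], bool]] = None):
    def wrap(cls):
        orig_init = cls.__init__

        def __init__(self, *, config: Any, **kwargs):
            orig_init(self, config=config, **kwargs)
            # Only mark the module for compilation when the predicate allows it.
            self.compile_enabled = enable_if is None or bool(enable_if(config))

        cls.__init__ = __init__
        return cls

    return wrap

@compile_if(enable_if=lambda cfg: not cfg.get("multimodal", False))
class DraftModel:  # hypothetical
    def __init__(self, *, config: dict):
        self.config = config

assert DraftModel(config={"multimodal": False}).compile_enabled
assert not DraftModel(config={"multimodal": True}).compile_enabled

This matches the diff's intent: multimodal EAGLE3 stays eager, everything else gets compiled.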
@@ -9,7 +9,12 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from vllm.config import CompilationLevel, VllmConfig, get_layers_from_vllm_config
+from vllm.config import (
+    CompilationLevel,
+    CUDAGraphMode,
+    VllmConfig,
+    get_layers_from_vllm_config,
+)
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
@@ -80,12 +85,25 @@ class EagleProposer:
         self.attn_layer_names: list[str] = []
         self.indexer_layer_names: list[str] = []
 
-        self.use_cuda_graph = (
-            not current_platform.is_xpu()
-            and self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
-            and not self.vllm_config.model_config.enforce_eager
-            and not self.speculative_config.enforce_eager
-        )
+        self.use_cuda_graph = False
+
+        compilation_config = self.vllm_config.compilation_config
+        if compilation_config.level == CompilationLevel.PIECEWISE:
+            cudagraph_mode = compilation_config.cudagraph_mode
+            if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
+                CUDAGraphMode.PIECEWISE
+            ):
+                logger.warning(
+                    "Currently the eagle proposer only supports cudagraph_mode "
+                    "PIECEWISE, if you want the drafter to use cuda graphs, "
+                    "please set compilation_config.cudagraph_mode to PIECEWISE "
+                    "or FULL_AND_PIECEWISE"
+                )
+            self.use_cuda_graph = (
+                cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
+                and not self.speculative_config.enforce_eager
+            )
 
         self.cudagraph_batch_sizes = (
             list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes))
             if self.use_cuda_graph
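Review note: compared with the old expression, the gate now keys off cudagraph_mode instead of platform checks, and the model_config.enforce_eager test is gone (presumably because enforce_eager already rules out a PIECEWISE compilation level). A boolean sketch with stand-in flags for the config fields read in this hunk:

def drafter_uses_cudagraphs(
    piecewise_compilation: bool,  # compilation_config.level == PIECEWISE
    mode_has_piecewise: bool,     # cudagraph_mode.has_mode(PIECEWISE)
    spec_enforce_eager: bool,     # speculative_config.enforce_eager
) -> bool:
    if not piecewise_compilation:
        return False  # use_cuda_graph stays False; the warning never fires
    return mode_has_piecewise and not spec_enforce_eager

assert drafter_uses_cudagraphs(True, True, False)
assert not drafter_uses_cudagraphs(True, False, False)  # e.g. FULL-only mode: warned
assert not drafter_uses_cudagraphs(False, True, False)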
@@ -239,12 +257,15 @@ class EagleProposer:
         per_layer_attn_metadata = {}
         for layer_name in self.attn_layer_names:
             per_layer_attn_metadata[layer_name] = attn_metadata
 
         for layer_name in self.indexer_layer_names:
             assert draft_indexer_metadata is not None
             per_layer_attn_metadata[layer_name] = draft_indexer_metadata
 
+        cudagraph_runtime_mode = CUDAGraphMode.NONE
         if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
             num_input_tokens = num_tokens
         # copy inputs to buffer for cudagraph
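Review note: this is the same pad-then-replay contract the target model uses. A sketch assuming pad_for_cudagraph rounds up to the nearest captured size (the sizes are illustrative):

import bisect

CAPTURE_SIZES = [1, 2, 4, 8, 16, 32, 64]  # ascending, like cudagraph_batch_sizes here

def pad_for_cudagraph(num_tokens: int) -> int:
    # Assumed behavior: round up to the nearest captured size.
    return CAPTURE_SIZES[bisect.bisect_left(CAPTURE_SIZES, num_tokens)]

def select_mode(num_tokens: int, use_cuda_graph: bool) -> tuple[int, str]:
    if use_cuda_graph and num_tokens <= CAPTURE_SIZES[-1]:
        return pad_for_cudagraph(num_tokens), "PIECEWISE"
    return num_tokens, "NONE"  # too large or disabled: eager fallback

assert select_mode(5, True) == (8, "PIECEWISE")  # padded up to a captured size
assert select_mode(100, True) == (100, "NONE")   # above the largest capture size
assert select_mode(5, False) == (5, "NONE")

The same PIECEWISE/NONE selection is repeated in the drafting paths in the hunks below.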
@@ -267,7 +288,10 @@ class EagleProposer:
             inputs_embeds = None
 
         with set_forward_context(
-            per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
+            per_layer_attn_metadata,
+            self.vllm_config,
+            num_tokens=num_input_tokens,
+            cudagraph_runtime_mode=cudagraph_runtime_mode,
         ):
             ret_hidden_states = self.model(
                 input_ids=input_ids,
@@ -326,8 +350,10 @@ class EagleProposer:
 
         if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
             input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
+            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
             input_batch_size = batch_size
+            cudagraph_runtime_mode = CUDAGraphMode.NONE
 
         common_attn_metadata.num_actual_tokens = batch_size
         common_attn_metadata.max_query_len = 1
@@ -424,7 +450,10 @@ class EagleProposer:
 
         # Run the model.
         with set_forward_context(
-            per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size
+            per_layer_attn_metadata,
+            self.vllm_config,
+            num_tokens=input_batch_size,
+            cudagraph_runtime_mode=cudagraph_runtime_mode,
         ):
             ret_hidden_states = self.model(
                 input_ids=input_ids,
@@ -731,11 +760,16 @@ class EagleProposer:
 
         if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
            num_input_tokens = num_tokens
+            cudagraph_runtime_mode = CUDAGraphMode.NONE
         # Run the model.
         with set_forward_context(
-            per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
+            per_layer_attn_metadata,
+            self.vllm_config,
+            num_tokens=num_input_tokens,
+            cudagraph_runtime_mode=cudagraph_runtime_mode,
         ):
             last_hidden_states, hidden_states = self.model(
                 input_ids=self.input_ids[:num_input_tokens],
@@ -1015,8 +1049,19 @@ class EagleProposer:
     def dummy_run(
         self,
         num_tokens: int,
+        use_cudagraphs=True,
     ) -> None:
-        with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
+        if use_cudagraphs and num_tokens <= self.cudagraph_batch_sizes[-1]:
+            num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+
+        with set_forward_context(
+            None,
+            self.vllm_config,
+            num_tokens=num_tokens,
+            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
+            if use_cudagraphs
+            else CUDAGraphMode.NONE,
+        ):
             if self.supports_mm_inputs:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
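Review note: padding in dummy_run matters because capture and replay must see identical shapes; without it, warmup would record graphs at sizes that never occur at decode time. A sketch under assumed capture sizes:

CAPTURE_SIZES = [1, 2, 4, 8, 16, 32]  # illustrative

def warmup_num_tokens(num_tokens: int, use_cudagraphs: bool) -> int:
    if use_cudagraphs and num_tokens <= CAPTURE_SIZES[-1]:
        # Same rounding the runtime paths apply before replaying a graph.
        return next(s for s in CAPTURE_SIZES if s >= num_tokens)
    return num_tokens

assert warmup_num_tokens(5, True) == 8   # matches the runtime padded size
assert warmup_num_tokens(5, False) == 5  # eager warmup keeps the raw size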
@@ -3441,7 +3441,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
-            self.drafter.dummy_run(num_tokens)
+            use_cudagraphs = cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
+            self.drafter.dummy_run(num_tokens, use_cudagraphs=use_cudagraphs)
 
         # This is necessary to avoid blocking DP.
         # For dummy runs, we typically skip EPLB since we don't have any real