diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 7ed757fd59b0..8046252c0b86 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -50,11 +50,14 @@ class CUDAGraphMode(enum.Enum):
     def mixed_mode(self) -> "CUDAGraphMode":
         return CUDAGraphMode(self.value[1]) if self.separate_routine() else self
 
+    def has_mode(self, mode: "CUDAGraphMode") -> bool:
+        assert not mode.separate_routine()
+        if self.separate_routine():
+            return mode.value in self.value
+        return self == mode
+
     def requires_piecewise_compilation(self) -> bool:
-        return (
-            self.decode_mode() == CUDAGraphMode.PIECEWISE
-            or self.mixed_mode() == CUDAGraphMode.PIECEWISE
-        )
+        return self.has_mode(CUDAGraphMode.PIECEWISE)
 
     def max_cudagraph_mode(self) -> "CUDAGraphMode":
         return CUDAGraphMode(max(self.value)) if self.separate_routine() else self
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 09da1398b030..37831e02f53f 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -283,6 +283,12 @@ def set_forward_context(
             vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
         )
 
+    # Convenience: if cudagraph is used and num_tokens is given, we can just
+    # create a batch descriptor here if not given (there's no harm since if it
+    # doesn't match in the wrapper it'll fall through).
+    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
+        batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)
+
     forward_context = create_forward_context(
         attn_metadata,
         vllm_config,
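Reviewer note, not part of the patch: a self-contained sketch of the has_mode() semantics introduced above for composite cudagraph modes. The member values below (NONE=0, PIECEWISE=1, FULL=2, FULL_DECODE_ONLY=(2, 0), FULL_AND_PIECEWISE=(2, 1)) are assumed for illustration rather than copied from vllm/config/compilation.py.

import enum


class CUDAGraphModeSketch(enum.Enum):
    # Assumed values: composite members carry a (decode, mixed) pair.
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    FULL_DECODE_ONLY = (2, 0)
    FULL_AND_PIECEWISE = (2, 1)

    def separate_routine(self) -> bool:
        return isinstance(self.value, tuple)

    def has_mode(self, mode: "CUDAGraphModeSketch") -> bool:
        # Same logic as the patch: a composite mode contains a simple mode if
        # the simple mode's value appears in its (decode, mixed) pair.
        assert not mode.separate_routine()
        if self.separate_routine():
            return mode.value in self.value
        return self == mode


assert CUDAGraphModeSketch.FULL_AND_PIECEWISE.has_mode(CUDAGraphModeSketch.PIECEWISE)
assert not CUDAGraphModeSketch.FULL_DECODE_ONLY.has_mode(CUDAGraphModeSketch.PIECEWISE)
assert CUDAGraphModeSketch.PIECEWISE.has_mode(CUDAGraphModeSketch.PIECEWISE)

With these semantics, requires_piecewise_compilation() reduces to has_mode(PIECEWISE), covering both the decode and mixed routines of a composite mode in one check.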
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index 041dd6db7325..bf3ab7bb3079 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -7,6 +7,7 @@ import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
 
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -162,6 +163,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
         return logits
 
 
+@support_torch_compile
 class DeepSeekMTP(nn.Module, SupportsPP):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 71f7274d2d64..67d466989919 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -8,6 +8,7 @@ import torch
 import torch.nn as nn
 from transformers import LlamaConfig
 
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -21,6 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import NestedTensors
 
 from .utils import AutoWeightsLoader, maybe_prefix
@@ -119,6 +121,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         return hidden_states, residual
 
 
+@support_torch_compile(
+    # torch.compile is disabled for multimodal EAGLE3 models due to constraint
+    # violations with dynamic shapes during tensor concatenation operations.
+    # See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132
+    # Non-multimodal EAGLE3 models can still use torch.compile safely.
+    enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs(
+        vllm_config.model_config
+    ),
+)
 class LlamaModel(nn.Module):
     def __init__(
         self,
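Reviewer note, not part of the patch: the enable_if hook above decides at model-construction time, from the config alone, whether @support_torch_compile applies. A self-contained sketch of that decision; the _VllmConfig, _ModelConfig, and _Registry classes are invented stand-ins, and only supports_multimodal_inputs() and vllm_config.model_config mirror names from the diff.

from dataclasses import dataclass


@dataclass
class _ModelConfig:
    is_multimodal: bool


@dataclass
class _VllmConfig:
    model_config: _ModelConfig


class _Registry:
    # Stand-in for MULTIMODAL_REGISTRY.supports_multimodal_inputs().
    def supports_multimodal_inputs(self, model_config: _ModelConfig) -> bool:
        return model_config.is_multimodal


_REGISTRY = _Registry()


# Same shape as the predicate passed to @support_torch_compile(enable_if=...):
# compile the EAGLE3 LlamaModel only when the target model is not multimodal.
def enable_if(vllm_config: _VllmConfig) -> bool:
    return not _REGISTRY.supports_multimodal_inputs(vllm_config.model_config)


assert enable_if(_VllmConfig(_ModelConfig(is_multimodal=False)))     # compiled
assert not enable_if(_VllmConfig(_ModelConfig(is_multimodal=True)))  # stays eager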
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 356ddfc9d986..6e88664f007d 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -9,7 +9,12 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-from vllm.config import CompilationLevel, VllmConfig, get_layers_from_vllm_config
+from vllm.config import (
+    CompilationLevel,
+    CUDAGraphMode,
+    VllmConfig,
+    get_layers_from_vllm_config,
+)
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
@@ -80,12 +85,25 @@ class EagleProposer:
         self.attn_layer_names: list[str] = []
         self.indexer_layer_names: list[str] = []
 
-        self.use_cuda_graph = (
-            not current_platform.is_xpu()
-            and self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
-            and not self.vllm_config.model_config.enforce_eager
-            and not self.speculative_config.enforce_eager
-        )
+        self.use_cuda_graph = False
+
+        compilation_config = self.vllm_config.compilation_config
+        if compilation_config.level == CompilationLevel.PIECEWISE:
+            cudagraph_mode = compilation_config.cudagraph_mode
+            if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode(
+                CUDAGraphMode.PIECEWISE
+            ):
+                logger.warning(
+                    "Currently the eagle proposer only supports cudagraph_mode "
+                    "PIECEWISE, if you want the drafter to use cuda graphs, "
+                    "please set compilation_config.cudagraph_mode to PIECEWISE "
+                    "or FULL_AND_PIECEWISE"
+                )
+            self.use_cuda_graph = (
+                cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE)
+                and not self.speculative_config.enforce_eager
+            )
+
         self.cudagraph_batch_sizes = (
             list(reversed(self.vllm_config.compilation_config.cudagraph_capture_sizes))
             if self.use_cuda_graph
@@ -239,12 +257,15 @@ class EagleProposer:
         per_layer_attn_metadata = {}
         for layer_name in self.attn_layer_names:
             per_layer_attn_metadata[layer_name] = attn_metadata
+
         for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata
 
+        cudagraph_runtime_mode = CUDAGraphMode.NONE
         if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
             num_input_tokens = num_tokens
         # copy inputs to buffer for cudagraph
@@ -267,7 +288,10 @@ class EagleProposer:
             inputs_embeds = None
 
         with set_forward_context(
-            per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
+            per_layer_attn_metadata,
+            self.vllm_config,
+            num_tokens=num_input_tokens,
+            cudagraph_runtime_mode=cudagraph_runtime_mode,
         ):
             ret_hidden_states = self.model(
                 input_ids=input_ids,
@@ -326,8 +350,10 @@ class EagleProposer:
 
         if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
             input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
+            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
             input_batch_size = batch_size
+            cudagraph_runtime_mode = CUDAGraphMode.NONE
 
         common_attn_metadata.num_actual_tokens = batch_size
         common_attn_metadata.max_query_len = 1
@@ -424,7 +450,10 @@ class EagleProposer:
 
         # Run the model.
         with set_forward_context(
-            per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size
+            per_layer_attn_metadata,
+            self.vllm_config,
+            num_tokens=input_batch_size,
+            cudagraph_runtime_mode=cudagraph_runtime_mode,
         ):
             ret_hidden_states = self.model(
                 input_ids=input_ids,
@@ -731,11 +760,16 @@ class EagleProposer:
 
         if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+            cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
         else:
             num_input_tokens = num_tokens
+            cudagraph_runtime_mode = CUDAGraphMode.NONE
         # Run the model.
         with set_forward_context(
-            per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
+            per_layer_attn_metadata,
+            self.vllm_config,
+            num_tokens=num_input_tokens,
+            cudagraph_runtime_mode=cudagraph_runtime_mode,
         ):
             last_hidden_states, hidden_states = self.model(
                 input_ids=self.input_ids[:num_input_tokens],
@@ -1015,8 +1049,19 @@ class EagleProposer:
     def dummy_run(
         self,
         num_tokens: int,
+        use_cudagraphs=True,
     ) -> None:
-        with set_forward_context(None, self.vllm_config, num_tokens=num_tokens):
+        if use_cudagraphs and num_tokens <= self.cudagraph_batch_sizes[-1]:
+            num_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
+
+        with set_forward_context(
+            None,
+            self.vllm_config,
+            num_tokens=num_tokens,
+            cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE
+            if use_cudagraphs
+            else CUDAGraphMode.NONE,
+        ):
             if self.supports_mm_inputs:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 108d14f17f3b..dce8c650e0eb 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -3441,7 +3441,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         if self.speculative_config and self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)
-            self.drafter.dummy_run(num_tokens)
+            use_cudagraphs = cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE
+            self.drafter.dummy_run(num_tokens, use_cudagraphs=use_cudagraphs)
 
         # This is necessary to avoid blocking DP.
         # For dummy runs, we typically skip EPLB since we don't have any real
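Reviewer note, not part of the patch: every call site touched in eagle.py repeats the same decision, run under PIECEWISE cuda graphs only when graphs are enabled and the batch fits under the largest captured size (self.cudagraph_batch_sizes[-1]), otherwise pass CUDAGraphMode.NONE and leave the token count unpadded. A self-contained sketch of that decision; the round-up below is an assumption standing in for vllm_config.pad_for_cudagraph.

from enum import Enum


class _Mode(Enum):
    # Stand-in for CUDAGraphMode.NONE / CUDAGraphMode.PIECEWISE.
    NONE = 0
    PIECEWISE = 1


def select_runtime_mode(
    num_tokens: int, use_cuda_graph: bool, capture_sizes: list[int]
) -> tuple[int, _Mode]:
    # capture_sizes is ascending, like self.cudagraph_batch_sizes, so the last
    # entry is the largest captured batch size.
    if use_cuda_graph and capture_sizes and num_tokens <= capture_sizes[-1]:
        padded = next(s for s in capture_sizes if s >= num_tokens)
        return padded, _Mode.PIECEWISE
    return num_tokens, _Mode.NONE


assert select_runtime_mode(3, True, [1, 2, 4, 8]) == (4, _Mode.PIECEWISE)
assert select_runtime_mode(9, True, [1, 2, 4, 8]) == (9, _Mode.NONE)
assert select_runtime_mode(3, False, [1, 2, 4, 8]) == (3, _Mode.NONE)

The dummy_run() and gpu_model_runner.py changes follow the same rule: the drafter's warm-up pass pads and runs under PIECEWISE only when the main model's dummy run is itself using piecewise cuda graphs.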