diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 75c671311b49..3eaf2d80082f 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -23,7 +23,6 @@ from vllm.model_executor.model_loader.weight_utils import (
     maybe_remap_kv_scale_name,
 )
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
-from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import NestedTensors
 
 from .utils import (
@@ -121,13 +120,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
 
 
 @support_torch_compile(
-    # torch.compile is disabled for multimodal EAGLE3 models due to constraint
-    # violations with dynamic shapes during tensor concatenation operations.
-    # See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132
-    # Non-multimodal EAGLE3 models can still use torch.compile safely.
-    enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs(
-        vllm_config.model_config
-    ),
+    dynamic_arg_dims={
+        "input_ids": 0,
+        "positions": -1,
+        "hidden_states": 0,
+        "input_embeds": 0,
+    }
 )
 class LlamaModel(nn.Module):
     def __init__(
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 5bf2503c3027..406bb696bd4c 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -116,9 +116,18 @@ class EagleProposer:
         )
         self.uses_mrope = self.vllm_config.model_config.uses_mrope
         if self.uses_mrope:
-            # M-RoPE need (3, max_num_tokens)
+            # NOTE: `mrope_positions` is implemented with one additional dummy
+            # position on purpose to make it non-contiguous so that it can work
+            # with torch compile.
+            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
+
+            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
+            # the modality of inputs. For text-only inputs, each dimension has
+            # identical position IDs, making M-RoPE functionally equivalent to
+            # 1D-RoPE.
+            # See page 5 of https://arxiv.org/abs/2409.12191
             self.mrope_positions = torch.zeros(
-                (3, self.max_num_tokens), dtype=torch.int64, device=device
+                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
             )
         else:
             # RoPE need (max_num_tokens,)
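
Note on the `dynamic_arg_dims` change in `llama_eagle3.py`: rather than disabling `torch.compile` for multimodal EAGLE3 models, the decorator now declares which dimension of each forward argument varies at runtime (dim 0 for token-indexed tensors, the last dim for M-RoPE's `(3, num_tokens)` positions), so Dynamo compiles one graph with symbolic sizes instead of hitting shape-constraint violations. A minimal standalone sketch of the underlying mechanism, assuming only stock PyTorch; the `Toy` module is hypothetical, and `torch._dynamo.mark_dynamic` is roughly what vLLM's decorator applies under the hood:

import torch

class Toy(torch.nn.Module):
    """Hypothetical stand-in for the EAGLE3 draft model's LlamaModel."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)

model = torch.compile(Toy())

for num_tokens in (3, 7, 12):
    hidden_states = torch.randn(num_tokens, 16)
    # Equivalent in spirit to dynamic_arg_dims={"hidden_states": 0}:
    # mark dim 0 as symbolic so one compiled graph serves every batch size
    # instead of recompiling (or failing shape guards) per num_tokens.
    torch._dynamo.mark_dynamic(hidden_states, 0)
    model(hidden_states)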
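
The `max_num_tokens + 1` over-allocation in `eagle.py` mirrors the trick already used for the main model's `mrope_positions` (see the discussion linked in the NOTE, on PR #12128): if the buffer were exactly `(3, max_num_tokens)`, the slice fed to the model would be contiguous whenever `num_tokens == max_num_tokens` but non-contiguous otherwise, giving `torch.compile` two memory layouts to guard on; the extra dummy column keeps every slice non-contiguous. A small illustration of the stride effect, using a hypothetical `max_num_tokens`:

import torch

max_num_tokens = 8  # hypothetical; vLLM derives this from the scheduler config

# Buffer with one extra dummy column, as allocated in the diff above.
mrope_positions = torch.zeros((3, max_num_tokens + 1), dtype=torch.int64)

# Any leading slice is a strided view: row stride stays max_num_tokens + 1.
view = mrope_positions[:, :5]
print(view.shape, view.stride(), view.is_contiguous())
# torch.Size([3, 5]) (9, 1) False

# Without the +1, the full-size slice would flip back to contiguous,
# so the layout seen by the compiled graph would depend on num_tokens.
flat = torch.zeros((3, max_num_tokens), dtype=torch.int64)
print(flat[:, :max_num_tokens].is_contiguous())             # True
print(mrope_positions[:, :max_num_tokens].is_contiguous())  # False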