Eagle: MM Cuda Graphs with MRope (#28896)

Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Izzy Putterman 2025-11-19 12:01:05 -08:00 committed by GitHub
parent ac10fd3c69
commit 02f5903b84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 10 deletions

View File

@ -23,7 +23,6 @@ from vllm.model_executor.model_loader.weight_utils import (
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from .utils import (
@ -121,13 +120,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
@support_torch_compile(
# torch.compile is disabled for multimodal EAGLE3 models due to constraint
# violations with dynamic shapes during tensor concatenation operations.
# See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132
# Non-multimodal EAGLE3 models can still use torch.compile safely.
enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs(
vllm_config.model_config
),
dynamic_arg_dims={
"input_ids": 0,
"positions": -1,
"hidden_states": 0,
"input_embeds": 0,
}
)
class LlamaModel(nn.Module):
def __init__(

View File

@ -116,9 +116,18 @@ class EagleProposer:
)
self.uses_mrope = self.vllm_config.model_config.uses_mrope
if self.uses_mrope:
# M-RoPE need (3, max_num_tokens)
# NOTE: `mrope_positions` is implemented with one additional dummy
# position on purpose to make it non-contiguous so that it can work
# with torch compile.
# See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
# NOTE: When M-RoPE is enabled, position ids are 3D regardless of
# the modality of inputs. For text-only inputs, each dimension has
# identical position IDs, making M-RoPE functionally equivalent to
# 1D-RoPE.
# See page 5 of https://arxiv.org/abs/2409.12191
self.mrope_positions = torch.zeros(
(3, self.max_num_tokens), dtype=torch.int64, device=device
(3, self.max_num_tokens + 1), dtype=torch.int64, device=device
)
else:
# RoPE need (max_num_tokens,)