mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 01:45:01 +08:00
Eagle: MM Cuda Graphs with MRope (#28896)
Signed-off-by: Izzy Putterman <iputterman@nvidia.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
ac10fd3c69
commit
02f5903b84
@@ -23,7 +23,6 @@ from vllm.model_executor.model_loader.weight_utils import (
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors

from .utils import (
(NOTE: the original diff's added/removed markers were lost in extraction; the hunk header indicates one line was removed from this import block — consult the upstream commit to confirm which.)
@@ -121,13 +120,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer):


@support_torch_compile(
    # torch.compile is disabled for multimodal EAGLE3 models due to constraint
    # violations with dynamic shapes during tensor concatenation operations.
    # See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132
    # Non-multimodal EAGLE3 models can still use torch.compile safely.
    enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs(
        vllm_config.model_config
    ),
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "hidden_states": 0,
        "input_embeds": 0,
    }
)
class LlamaModel(nn.Module):
    def __init__(
(NOTE: added/removed markers for this hunk were lost in extraction; lines are shown as they appear in the rendered diff.)
@@ -116,9 +116,18 @@ class EagleProposer:
         )
         self.uses_mrope = self.vllm_config.model_config.uses_mrope
         if self.uses_mrope:
             # M-RoPE need (3, max_num_tokens)
             # NOTE: `mrope_positions` is implemented with one additional dummy
             # position on purpose to make it non-contiguous so that it can work
             # with torch compile.
             # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923

             # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
             # the modality of inputs. For text-only inputs, each dimension has
             # identical position IDs, making M-RoPE functionally equivalent to
             # 1D-RoPE.
             # See page 5 of https://arxiv.org/abs/2409.12191
             self.mrope_positions = torch.zeros(
-                (3, self.max_num_tokens), dtype=torch.int64, device=device
+                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
             )
         else:
             # RoPE need (max_num_tokens,)
Loading…
x
Reference in New Issue
Block a user