Eagle: MM Cuda Graphs with MRope (#28896)
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent ac10fd3c69
commit 02f5903b84
@@ -23,7 +23,6 @@ from vllm.model_executor.model_loader.weight_utils import (
     maybe_remap_kv_scale_name,
 )
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
-from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import NestedTensors
 
 from .utils import (
@@ -121,13 +120,12 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
 
 
 @support_torch_compile(
-    # torch.compile is disabled for multimodal EAGLE3 models due to constraint
-    # violations with dynamic shapes during tensor concatenation operations.
-    # See: https://github.com/vllm-project/vllm/pull/22872/files#r2362028132
-    # Non-multimodal EAGLE3 models can still use torch.compile safely.
-    enable_if=lambda vllm_config: not MULTIMODAL_REGISTRY.supports_multimodal_inputs(
-        vllm_config.model_config
-    ),
+    dynamic_arg_dims={
+        "input_ids": 0,
+        "positions": -1,
+        "hidden_states": 0,
+        "input_embeds": 0,
+    }
 )
 class LlamaModel(nn.Module):
     def __init__(
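For context, a minimal sketch of what the `dynamic_arg_dims` mapping expresses, using only public PyTorch APIs; `mark_dynamic_dims` is a hypothetical helper for illustration, not vLLM's actual decorator machinery. The mapping names one dimension of each `forward()` argument as dynamic, which is why the `enable_if` escape hatch for multimodal EAGLE3 can be dropped. `positions` uses `-1` because the tensor is `(num_tokens,)` for 1D RoPE but `(3, num_tokens)` for M-RoPE, so the varying dimension is always the last one.

import torch

dynamic_arg_dims = {
    "input_ids": 0,      # (num_tokens,): token count varies per batch
    "positions": -1,     # last dim varies: (num_tokens,) or (3, num_tokens)
    "hidden_states": 0,  # (num_tokens, hidden_size)
    "input_embeds": 0,   # (num_tokens, hidden_size)
}

def mark_dynamic_dims(kwargs: dict[str, torch.Tensor]) -> None:
    """Mark the configured dim of each present argument as dynamic so
    torch.compile does not recompile for every new token count."""
    for name, dim in dynamic_arg_dims.items():
        tensor = kwargs.get(name)
        if tensor is None:
            continue
        # Normalize negative indices before handing them to Dynamo.
        dim = dim if dim >= 0 else tensor.ndim + dim
        torch._dynamo.mark_dynamic(tensor, dim)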
@@ -116,9 +116,18 @@ class EagleProposer:
         )
         self.uses_mrope = self.vllm_config.model_config.uses_mrope
         if self.uses_mrope:
-            # M-RoPE need (3, max_num_tokens)
+            # NOTE: `mrope_positions` is implemented with one additional dummy
+            # position on purpose to make it non-contiguous so that it can work
+            # with torch compile.
+            # See detailed explanation in https://github.com/vllm-project/vllm/pull/12128#discussion_r1926431923
+
+            # NOTE: When M-RoPE is enabled, position ids are 3D regardless of
+            # the modality of inputs. For text-only inputs, each dimension has
+            # identical position IDs, making M-RoPE functionally equivalent to
+            # 1D-RoPE.
+            # See page 5 of https://arxiv.org/abs/2409.12191
             self.mrope_positions = torch.zeros(
-                (3, self.max_num_tokens), dtype=torch.int64, device=device
+                (3, self.max_num_tokens + 1), dtype=torch.int64, device=device
             )
         else:
             # RoPE need (max_num_tokens,)
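A standalone illustration of the dummy-position trick in plain PyTorch (not vLLM code): slicing the first `max_num_tokens` columns of a `(3, max_num_tokens + 1)` buffer yields a non-contiguous view, which is what the torch.compile workaround linked in the NOTE relies on, and for text-only inputs the three M-RoPE rows carry identical position ids, so M-RoPE reduces to ordinary 1D RoPE.

import torch

max_num_tokens = 8

# Allocate with one extra dummy column, as in the change above.
mrope_positions = torch.zeros((3, max_num_tokens + 1), dtype=torch.int64)

# The slice actually used covers only max_num_tokens columns; because the
# underlying row stride is max_num_tokens + 1, the view is non-contiguous.
positions = mrope_positions[:, :max_num_tokens]
assert not positions.is_contiguous()

# Text-only inputs: all three dimensions hold the same 1D position ids.
positions[:] = torch.arange(max_num_tokens)
assert torch.equal(positions[0], positions[2])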