[Bugfix][SpecDecode] Adjust Eagle model architecture to align with intended design (#11672)

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
This commit is contained in:
Sungjae Lee 2025-01-11 13:49:38 +09:00 committed by GitHub
parent 899136b857
commit 2118d0565c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -17,14 +17,30 @@ from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
class DummyInputLayerNorm(nn.Module):
def forward(self, x):
return x
class DummyOutputNorm(nn.Module):
def forward(self, x, residual):
if residual is None:
return x
else:
return x, residual
class EAGLE(nn.Module):
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
Reference implementation: https://github.com/SafeAILab/EAGLE
Differences from reference implementation:
1. In reference, LlamaDecoderLayer implementation doesn't have
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427)
but we do as HF implementation also does.
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
Following this approach, our implementation also disables
the input_layernorm for the first decoder layer.
2. We allow any decoder layer to be used in EAGLE whereas in reference
decoder layer is fixed to be LlamaDecoderLayer.
3. We have an optional token_map which reduces draft vocab to most
@ -46,10 +62,16 @@ class EAGLE(nn.Module):
self.model = model_cls(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.fc = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size,
bias=getattr(self.config, "eagle_fc_bias", False))
# Modify layer normalization and residual connections as suggested
# in the EAGLE framework: https://github.com/SafeAILab/EAGLE
self.model.model.layers[0].input_layernorm = DummyInputLayerNorm()
self.model.model.norm = DummyOutputNorm()
self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size