mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-25 05:35:02 +08:00
[Bugfix][SpecDecode] Adjust Eagle model architecture to align with intended design (#11672)
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
This commit is contained in:
parent
899136b857
commit
2118d0565c
@ -17,14 +17,30 @@ from vllm.sequence import IntermediateTensors
|
||||
from .utils import maybe_prefix
|
||||
|
||||
|
||||
class DummyInputLayerNorm(nn.Module):
    """Identity stand-in for a decoder layer's ``input_layernorm``.

    EAGLE assigns an instance of this class to
    ``self.model.model.layers[0].input_layernorm`` so that the first decoder
    layer applies no input normalization.
    """

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x
|
||||
|
||||
|
||||
class DummyOutputNorm(nn.Module):
    """Pass-through stand-in for the wrapped model's final ``norm`` module.

    EAGLE assigns an instance of this class to ``self.model.model.norm`` so
    that no output normalization is applied; the inputs are handed back
    untouched.
    """

    def forward(self, x, residual=None):
        """Return the inputs without normalizing them.

        Args:
            x: The hidden states.
            residual: Optional residual stream. Defaults to ``None`` so the
                module can also be called with a single argument
                (assumption: the layer being replaced may be invoked that
                way -- confirm against the wrapped model).

        Returns:
            ``x`` alone when ``residual`` is ``None``; otherwise the tuple
            ``(x, residual)``, both unchanged.
        """
        if residual is None:
            return x
        return x, residual
|
||||
|
||||
|
||||
class EAGLE(nn.Module):
|
||||
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
|
||||
Reference implementation: https://github.com/SafeAILab/EAGLE
|
||||
|
||||
Differences from reference implementation:
|
||||
1. In reference, LlamaDecoderLayer implementation doesn't have
|
||||
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427)
|
||||
but we do as HF implementation also does.
|
||||
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
|
||||
Following this approach, our implementation also disables
|
||||
the input_layernorm for the first decoder layer.
|
||||
2. We allow any decoder layer to be used in EAGLE whereas in reference
|
||||
decoder layer is fixed to be LlamaDecoderLayer.
|
||||
3. We have an optional token_map which reduces draft vocab to most
|
||||
@ -46,10 +62,16 @@ class EAGLE(nn.Module):
|
||||
|
||||
self.model = model_cls(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
self.fc = nn.Linear(config.model.hidden_size * 2,
|
||||
config.model.hidden_size,
|
||||
bias=getattr(self.config, "eagle_fc_bias", False))
|
||||
|
||||
# Modify layer normalization and residual connections as suggested
|
||||
# in the EAGLE framework: https://github.com/SafeAILab/EAGLE
|
||||
self.model.model.layers[0].input_layernorm = DummyInputLayerNorm()
|
||||
self.model.model.norm = DummyOutputNorm()
|
||||
|
||||
self.orig_vocab_size = config.vocab_size
|
||||
self.truncated_vocab_size = config.truncated_vocab_size
|
||||
self.unpadded_vocab_size = self.truncated_vocab_size
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user