mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-11 15:51:24 +08:00
[Bugfix][SpecDecode] Adjust Eagle model architecture to align with intended design (#11672)
Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
This commit is contained in:
parent
899136b857
commit
2118d0565c
@ -17,14 +17,30 @@ from vllm.sequence import IntermediateTensors
|
|||||||
from .utils import maybe_prefix
|
from .utils import maybe_prefix
|
||||||
|
|
||||||
|
|
||||||
|
class DummyInputLayerNorm(nn.Module):
    """Identity stand-in for an input layernorm.

    The EAGLE reference implementation omits ``input_layernorm`` on the
    first decoder layer; installing this no-op module in its place lets a
    stock decoder layer be reused without modifying its forward pass.
    """

    def forward(self, x):
        # Pass the hidden states through untouched (identity transform).
        return x
|
class DummyOutputNorm(nn.Module):
    """No-op replacement for the decoder's final output norm.

    EAGLE skips the last normalization of the base model, so this module
    mirrors the (x, residual) calling convention of a fused norm layer
    while leaving both tensors unchanged.
    """

    def forward(self, x, residual):
        # Mirror the fused-norm contract: a lone tensor when no residual
        # stream is present, otherwise the unchanged (x, residual) pair.
        return x if residual is None else (x, residual)
class EAGLE(nn.Module):
|
class EAGLE(nn.Module):
|
||||||
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
|
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
|
||||||
Reference implementation: https://github.com/SafeAILab/EAGLE
|
Reference implementation: https://github.com/SafeAILab/EAGLE
|
||||||
|
|
||||||
Differences from reference implementation:
|
Differences from reference implementation:
|
||||||
1. In reference, LlamaDecoderLayer implementation doesn't have
|
1. In reference, LlamaDecoderLayer implementation doesn't have
|
||||||
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427)
|
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
|
||||||
but we do as HF implementation also does.
|
Following this approach, our implementation also disables
|
||||||
|
the input_layernorm for the first decoder layer.
|
||||||
2. We allow any decoder layer to be used in EAGLE whereas in reference
|
2. We allow any decoder layer to be used in EAGLE whereas in reference
|
||||||
decoder layer is fixed to be LlamaDecoderLayer.
|
decoder layer is fixed to be LlamaDecoderLayer.
|
||||||
3. We have an optional token_map which reduces draft vocab to most
|
3. We have an optional token_map which reduces draft vocab to most
|
||||||
@ -46,10 +62,16 @@ class EAGLE(nn.Module):
|
|||||||
|
|
||||||
self.model = model_cls(vllm_config=vllm_config,
|
self.model = model_cls(vllm_config=vllm_config,
|
||||||
prefix=maybe_prefix(prefix, "model"))
|
prefix=maybe_prefix(prefix, "model"))
|
||||||
|
|
||||||
self.fc = nn.Linear(config.model.hidden_size * 2,
|
self.fc = nn.Linear(config.model.hidden_size * 2,
|
||||||
config.model.hidden_size,
|
config.model.hidden_size,
|
||||||
bias=getattr(self.config, "eagle_fc_bias", False))
|
bias=getattr(self.config, "eagle_fc_bias", False))
|
||||||
|
|
||||||
|
# Modify layer normalization and residual connections as suggested
|
||||||
|
# in the EAGLE framework: https://github.com/SafeAILab/EAGLE
|
||||||
|
self.model.model.layers[0].input_layernorm = DummyInputLayerNorm()
|
||||||
|
self.model.model.norm = DummyOutputNorm()
|
||||||
|
|
||||||
self.orig_vocab_size = config.vocab_size
|
self.orig_vocab_size = config.vocab_size
|
||||||
self.truncated_vocab_size = config.truncated_vocab_size
|
self.truncated_vocab_size = config.truncated_vocab_size
|
||||||
self.unpadded_vocab_size = self.truncated_vocab_size
|
self.unpadded_vocab_size = self.truncated_vocab_size
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user