diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index f138d13630263..eb7b5af19ae96 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -17,14 +17,30 @@ from vllm.sequence import IntermediateTensors from .utils import maybe_prefix +class DummyInputLayerNorm(nn.Module): + + def forward(self, x): + return x + + +class DummyOutputNorm(nn.Module): + + def forward(self, x, residual): + if residual is None: + return x + else: + return x, residual + + class EAGLE(nn.Module): """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077 Reference implementation: https://github.com/SafeAILab/EAGLE Differences from reference implementation: 1. In reference, LlamaDecoderLayer implementation doesn't have - input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427) - but we do as HF implementation also does. + input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427). + Following this approach, our implementation also disables + the input_layernorm for the first decoder layer. 2. We allow any decoder layer to be used in EAGLE whereas in reference decoder layer is fixed to be LlamaDecoderLayer. 3. We have an optional token_map which reduces draft vocab to most @@ -46,10 +62,16 @@ class EAGLE(nn.Module): self.model = model_cls(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + self.fc = nn.Linear(config.model.hidden_size * 2, config.model.hidden_size, bias=getattr(self.config, "eagle_fc_bias", False)) + # Modify layer normalization and residual connections as suggested + # in the EAGLE framework: https://github.com/SafeAILab/EAGLE + self.model.model.layers[0].input_layernorm = DummyInputLayerNorm() + self.model.model.norm = DummyOutputNorm() + self.orig_vocab_size = config.vocab_size self.truncated_vocab_size = config.truncated_vocab_size self.unpadded_vocab_size = self.truncated_vocab_size