Add: Support for multiple hidden layers in Eagle3 (#26164)

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
This commit is contained in:
Rahul Tuli 2025-10-09 13:00:50 +05:30 committed by GitHub
parent b960441812
commit cf4cd6c24f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 29 additions and 13 deletions

View File

@@ -22,6 +22,10 @@ from vllm.model_executor.models.interfaces import supports_eagle3
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16", "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
id="qwen3-eagle3-speculator-w4a16-verifier", id="qwen3-eagle3-speculator-w4a16-verifier",
), ),
pytest.param(
"nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
id="llama3-eagl3-multiple-layers",
),
], ],
) )
def test_eagle3_speculators_model( def test_eagle3_speculators_model(

View File

@@ -34,15 +34,20 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
config: Optional[LlamaConfig] = None, config: Optional[LlamaConfig] = None,
layer_idx: int = 0,
) -> None: ) -> None:
super().__init__(vllm_config, prefix=prefix, config=config) super().__init__(vllm_config, prefix=prefix, config=config)
config = config or vllm_config.model_config.hf_config config = config or vllm_config.model_config.hf_config
quant_config = self.get_quant_config(vllm_config) quant_config = self.get_quant_config(vllm_config)
# First layer uses 2*hidden_size (embeds + hidden_states concatenated)
# Subsequent layers use hidden_size (only hidden_states, no embeds)
qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size
# override qkv # override qkv
self.self_attn.qkv_proj = QKVParallelLinear( self.self_attn.qkv_proj = QKVParallelLinear(
2 * self.hidden_size, qkv_input_size,
self.self_attn.head_dim, self.self_attn.head_dim,
self.self_attn.total_num_heads, self.self_attn.total_num_heads,
self.self_attn.total_num_kv_heads, self.self_attn.total_num_kv_heads,
@@ -52,6 +57,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
) )
self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.layer_idx = layer_idx
if getattr(config, "norm_before_residual", False): if getattr(config, "norm_before_residual", False):
self._residual_norm = self._norm_before_residual self._residual_norm = self._norm_before_residual
@@ -90,11 +96,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
embeds = self.input_layernorm(embeds) if self.layer_idx == 0:
# First layer: concatenate embeds with hidden_states
embeds = self.input_layernorm(embeds)
hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
hidden_states = torch.cat([embeds, hidden_states], dim=-1)
else:
# Subsequent layers: process hidden_states and residuals only
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
hidden_states = torch.cat([embeds, hidden_states], dim=-1)
# Self Attention # Self Attention
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
@@ -133,9 +143,11 @@ class LlamaModel(nn.Module):
[ [
LlamaDecoderLayer( LlamaDecoderLayer(
current_vllm_config, current_vllm_config,
prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"), prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
config=self.config, config=self.config,
layer_idx=layer_idx,
) )
for layer_idx in range(self.config.num_hidden_layers)
] ]
) )
if hasattr(self.config, "target_hidden_size"): if hasattr(self.config, "target_hidden_size"):
@@ -166,13 +178,13 @@ class LlamaModel(nn.Module):
assert hidden_states.shape[-1] == input_embeds.shape[-1] assert hidden_states.shape[-1] == input_embeds.shape[-1]
residual = None residual = None
hidden_states, residual = self.layers[0]( for layer in self.layers:
positions, hidden_states, residual = layer(
input_embeds, positions=positions,
hidden_states, embeds=input_embeds,
residual, hidden_states=hidden_states,
) residual=residual,
)
hidden_states, hidden_prenorm = self.norm(hidden_states, residual) hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
return hidden_states, hidden_prenorm return hidden_states, hidden_prenorm