mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 02:43:09 +08:00
Add: Support for multiple hidden layers in Eagle3 (#26164)
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
This commit is contained in:
parent
b960441812
commit
cf4cd6c24f
@ -22,6 +22,10 @@ from vllm.model_executor.models.interfaces import supports_eagle3
|
|||||||
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
|
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
|
||||||
id="qwen3-eagle3-speculator-w4a16-verifier",
|
id="qwen3-eagle3-speculator-w4a16-verifier",
|
||||||
),
|
),
|
||||||
|
pytest.param(
|
||||||
|
"nm-testing/random-weights-llama3.1.8b-2layer-eagle3",
|
||||||
|
id="llama3-eagl3-multiple-layers",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_eagle3_speculators_model(
|
def test_eagle3_speculators_model(
|
||||||
|
|||||||
@ -34,15 +34,20 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
|
|||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
config: Optional[LlamaConfig] = None,
|
config: Optional[LlamaConfig] = None,
|
||||||
|
layer_idx: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(vllm_config, prefix=prefix, config=config)
|
super().__init__(vllm_config, prefix=prefix, config=config)
|
||||||
|
|
||||||
config = config or vllm_config.model_config.hf_config
|
config = config or vllm_config.model_config.hf_config
|
||||||
quant_config = self.get_quant_config(vllm_config)
|
quant_config = self.get_quant_config(vllm_config)
|
||||||
|
|
||||||
|
# First layer uses 2*hidden_size (embeds + hidden_states concatenated)
|
||||||
|
# Subsequent layers use hidden_size (only hidden_states, no embeds)
|
||||||
|
qkv_input_size = 2 * self.hidden_size if layer_idx == 0 else self.hidden_size
|
||||||
|
|
||||||
# override qkv
|
# override qkv
|
||||||
self.self_attn.qkv_proj = QKVParallelLinear(
|
self.self_attn.qkv_proj = QKVParallelLinear(
|
||||||
2 * self.hidden_size,
|
qkv_input_size,
|
||||||
self.self_attn.head_dim,
|
self.self_attn.head_dim,
|
||||||
self.self_attn.total_num_heads,
|
self.self_attn.total_num_heads,
|
||||||
self.self_attn.total_num_kv_heads,
|
self.self_attn.total_num_kv_heads,
|
||||||
@ -52,6 +57,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
if getattr(config, "norm_before_residual", False):
|
if getattr(config, "norm_before_residual", False):
|
||||||
self._residual_norm = self._norm_before_residual
|
self._residual_norm = self._norm_before_residual
|
||||||
@ -90,11 +96,15 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
embeds = self.input_layernorm(embeds)
|
if self.layer_idx == 0:
|
||||||
|
# First layer: concatenate embeds with hidden_states
|
||||||
|
embeds = self.input_layernorm(embeds)
|
||||||
|
hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
|
||||||
|
hidden_states = torch.cat([embeds, hidden_states], dim=-1)
|
||||||
|
else:
|
||||||
|
# Subsequent layers: process hidden_states and residuals only
|
||||||
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
|
|
||||||
|
|
||||||
hidden_states = torch.cat([embeds, hidden_states], dim=-1)
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
hidden_states = self.self_attn(
|
hidden_states = self.self_attn(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
@ -133,9 +143,11 @@ class LlamaModel(nn.Module):
|
|||||||
[
|
[
|
||||||
LlamaDecoderLayer(
|
LlamaDecoderLayer(
|
||||||
current_vllm_config,
|
current_vllm_config,
|
||||||
prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
|
prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
|
||||||
config=self.config,
|
config=self.config,
|
||||||
|
layer_idx=layer_idx,
|
||||||
)
|
)
|
||||||
|
for layer_idx in range(self.config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
if hasattr(self.config, "target_hidden_size"):
|
if hasattr(self.config, "target_hidden_size"):
|
||||||
@ -166,13 +178,13 @@ class LlamaModel(nn.Module):
|
|||||||
assert hidden_states.shape[-1] == input_embeds.shape[-1]
|
assert hidden_states.shape[-1] == input_embeds.shape[-1]
|
||||||
|
|
||||||
residual = None
|
residual = None
|
||||||
hidden_states, residual = self.layers[0](
|
for layer in self.layers:
|
||||||
positions,
|
hidden_states, residual = layer(
|
||||||
input_embeds,
|
positions=positions,
|
||||||
hidden_states,
|
embeds=input_embeds,
|
||||||
residual,
|
hidden_states=hidden_states,
|
||||||
)
|
residual=residual,
|
||||||
|
)
|
||||||
hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
|
hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
|
||||||
return hidden_states, hidden_prenorm
|
return hidden_states, hidden_prenorm
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user