diff --git a/tests/speculative_decoding/speculators/test_eagle3.py b/tests/speculative_decoding/speculators/test_eagle3.py
index c58fc8c0dc5f..c46ac7a88b75 100644
--- a/tests/speculative_decoding/speculators/test_eagle3.py
+++ b/tests/speculative_decoding/speculators/test_eagle3.py
@@ -6,11 +6,21 @@ import torch
 
 @pytest.mark.parametrize(
     "model_path",
-    [("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717"),
-     ("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
+    [("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
 def test_llama(vllm_runner, example_prompts, model_path):
     with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                   max_tokens=20)
         print(vllm_outputs)
         assert vllm_outputs
+
+
+@pytest.mark.parametrize(
+    "model_path",
+    [("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")])
+def test_qwen(vllm_runner, example_prompts, model_path):
+    with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                  max_tokens=20)
+        print(vllm_outputs)
+        assert vllm_outputs
diff --git a/vllm/config.py b/vllm/config.py
index dabb4b524dfd..95dae4275edf 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3175,10 +3175,19 @@ class SpeculativeConfig:
                     "speculative decoding is > 1, but got "
                     f"{self.disable_by_batch_size=}")
 
-        if self.method == "eagle3" and self.target_model_config and \
-                "llama" not in self.target_model_config.hf_text_config.model_type:
+        from vllm.transformers_utils.configs import SpeculatorsConfig
+
+        eagle3_target_supported = ["llama"]
+        if self.draft_model_config and isinstance(
+                self.draft_model_config.hf_config, SpeculatorsConfig):
+            eagle3_target_supported.append("qwen")
+
+        if self.method == "eagle3" and self.target_model_config and not any(
+                supported_model in
+                self.target_model_config.hf_text_config.model_type
+                for supported_model in eagle3_target_supported):
             raise ValueError(
-                "Eagle3 is only supported for Llama models. "
+                f"Eagle3 is only supported for {eagle3_target_supported} models. "  # noqa: E501
                 f"Got {self.target_model_config.hf_text_config.model_type=}")
 
         return self
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 23f65b99c22c..0e7507a4570b 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -330,6 +330,8 @@ class Qwen2Model(nn.Module):
         else:
             self.norm = PPMissingLayer()
 
+        self.aux_hidden_state_layers: tuple[int] = tuple()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
@@ -350,18 +352,25 @@ class Qwen2Model(nn.Module):
             assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
-        for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(
-                positions,
-                hidden_states,
-                residual,
-            )
+
+        aux_hidden_states = []
+        for idx, layer in enumerate(
+                self.layers[self.start_layer:self.end_layer]):
+            if idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(hidden_states + residual)
+            hidden_states, residual = layer(positions, hidden_states, residual)
+
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
                 "residual": residual
             })
+
         hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
+
         return hidden_states
 
     def load_weights(self, weights: Iterable[tuple[str,
diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py
index 393ce41a91a0..d2ae8959b103 100644
--- a/vllm/model_executor/models/qwen3.py
+++ b/vllm/model_executor/models/qwen3.py
@@ -288,6 +288,13 @@ class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)