From c33992154a64ae5f536e71981d37f4a8ac0b55c8 Mon Sep 17 00:00:00 2001
From: Rahul Tuli
Date: Mon, 29 Sep 2025 21:07:20 +0530
Subject: [PATCH] [Bugfix][Speculative Decoding] Fix Eagle3 quantization config issue (#25883)

Signed-off-by: Rahul Tuli
Signed-off-by: yewentao256
---
 .../speculators/test_eagle3.py             |  3 +++
 vllm/model_executor/models/llama.py        |  7 ++++++-
 vllm/model_executor/models/llama_eagle3.py | 14 +++++++++++++-
 3 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/tests/speculative_decoding/speculators/test_eagle3.py b/tests/speculative_decoding/speculators/test_eagle3.py
index 368238b3a7200..87d799a5fed70 100644
--- a/tests/speculative_decoding/speculators/test_eagle3.py
+++ b/tests/speculative_decoding/speculators/test_eagle3.py
@@ -14,6 +14,9 @@ from vllm.model_executor.models.interfaces import supports_eagle3
     pytest.param(
         "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
         id="qwen3-eagle3-speculator"),
+    pytest.param(
+        "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
+        id="qwen3-eagle3-speculator-w4a16-verifier"),
 ])
 def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path,
                                   monkeypatch):
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index c7dd134ea47e9..a6081d3315118 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -248,7 +248,7 @@ class LlamaDecoderLayer(nn.Module):
         config = config or vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
+        quant_config = self.get_quant_config(vllm_config)

         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -328,6 +328,11 @@ class LlamaDecoderLayer(nn.Module):
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual

+    def get_quant_config(
+            self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
+        """Get quantization config for this layer. Override in subclasses."""
+        return vllm_config.quant_config
+

 @support_torch_compile
 class LlamaModel(nn.Module):
diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index 7192a76c87498..3fb6f2f8d5ecf 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -13,6 +13,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -33,7 +35,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         super().__init__(vllm_config, prefix=prefix, config=config)

         config = config or vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
+        quant_config = self.get_quant_config(vllm_config)

         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
@@ -53,6 +55,16 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         else:
             self._residual_norm = self._norm_after_residual

+    def get_quant_config(
+            self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
+        """Use drafter's quantization config instead of verifier's."""
+        draft_model_config = vllm_config.speculative_config.draft_model_config
+        draft_load_config = vllm_config.load_config
+
+        return VllmConfig.get_quantization_config(
+            draft_model_config,
+            draft_load_config) if draft_model_config else None
+
     def _norm_before_residual(
             self,
             hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
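
Reviewer note (not part of the patch): the heart of the fix is the new get_quant_config hook. The base LlamaDecoderLayer keeps returning the global quantization config, which belongs to the verifier, while the Eagle3 drafter layer overrides the hook to resolve the draft model's own config. Because the base __init__ calls the hook via self, the override already takes effect for the sublayers the base constructor builds, so a quantized verifier no longer forces its scheme onto the drafter. Below is a minimal, self-contained sketch of that dispatch pattern; the Stub* classes are hypothetical stand-ins for illustration, not the real vLLM types.

from dataclasses import dataclass
from typing import Optional


# Hypothetical stand-ins for illustration only; the real vLLM objects
# (VllmConfig, ModelConfig, QuantizationConfig) carry far more state.
@dataclass
class StubQuantConfig:
    scheme: str


@dataclass
class StubDraftModelConfig:
    quant: Optional[StubQuantConfig]


@dataclass
class StubSpeculativeConfig:
    draft_model_config: Optional[StubDraftModelConfig]


@dataclass
class StubVllmConfig:
    quant_config: Optional[StubQuantConfig]            # verifier's (global) config
    speculative_config: Optional[StubSpeculativeConfig]


class BaseDecoderLayer:
    """Verifier-style layer: resolves quantization through an overridable hook."""

    def __init__(self, cfg: StubVllmConfig):
        # Called via `self`, so subclasses constructed through
        # super().__init__() get their own resolution logic here.
        self.quant_config = self.get_quant_config(cfg)

    def get_quant_config(self, cfg: StubVllmConfig) -> Optional[StubQuantConfig]:
        # Default: the global config, i.e. the verifier's scheme.
        return cfg.quant_config


class DrafterDecoderLayer(BaseDecoderLayer):
    """Eagle3-style drafter layer: reads the draft model's config instead."""

    def get_quant_config(self, cfg: StubVllmConfig) -> Optional[StubQuantConfig]:
        draft = cfg.speculative_config.draft_model_config
        return draft.quant if draft else None


if __name__ == "__main__":
    # A w4a16-quantized verifier speculating with an unquantized drafter.
    cfg = StubVllmConfig(
        quant_config=StubQuantConfig("w4a16"),
        speculative_config=StubSpeculativeConfig(
            StubDraftModelConfig(quant=None)),
    )
    assert BaseDecoderLayer(cfg).quant_config.scheme == "w4a16"
    assert DrafterDecoderLayer(cfg).quant_config is None  # drafter stays unquantized

To exercise the real code path, running the test file added above with a keyword filter, e.g. pytest tests/speculative_decoding/speculators/test_eagle3.py -k w4a16, should select the new qwen3-eagle3-speculator-w4a16-verifier case.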