[Bugfix][Speculative Decoding] Fix Eagle3 quantization config issue (#25883)

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Author: Rahul Tuli
Date: 2025-09-29 21:07:20 +05:30 (committed via GitHub)
Parent: d0d138bc55
Commit: 145ac73317
3 changed files with 22 additions and 2 deletions


@@ -14,6 +14,9 @@ from vllm.model_executor.models.interfaces import supports_eagle3
     pytest.param(
         "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
         id="qwen3-eagle3-speculator"),
+    pytest.param(
+        "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
+        id="qwen3-eagle3-speculator-w4a16-verifier"),
 ])
 def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path,
                                   monkeypatch):
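For reference, the new w4a16 case can also be exercised outside the test harness. A minimal sketch, assuming the converted speculator checkpoint is passed directly as the model path, just as the parametrized test above does through the vllm_runner fixture (prompt and sampling values are illustrative):

from vllm import LLM, SamplingParams

# Load the quantized Eagle3 speculator checkpoint added to the test matrix above.
llm = LLM(
    model="nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16")

# Generate with the drafter proposing tokens and the verifier checking them.
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32, temperature=0.0))
print(outputs[0].outputs[0].text)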


@@ -248,7 +248,7 @@ class LlamaDecoderLayer(nn.Module):
         config = config or vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
+        quant_config = self.get_quant_config(vllm_config)

         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -328,6 +328,11 @@ class LlamaDecoderLayer(nn.Module):
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual

+    def get_quant_config(
+            self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
+        """Get quantization config for this layer. Override in subclasses."""
+        return vllm_config.quant_config
+

 @support_torch_compile
 class LlamaModel(nn.Module):
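The new get_quant_config hook turns quant-config resolution into an overridable template method: the constructor no longer reads vllm_config.quant_config directly, so a subclass can substitute a different config without re-implementing __init__. A stripped-down sketch of that dispatch, using hypothetical stand-in classes rather than the actual vLLM types:

from typing import Optional


class Config:
    """Hypothetical container with verifier- and drafter-side quant settings."""

    def __init__(self, verifier_quant: Optional[str],
                 drafter_quant: Optional[str]):
        self.quant_config = verifier_quant          # what vllm_config.quant_config holds
        self.drafter_quant_config = drafter_quant   # stand-in for the draft model's config


class BaseLayer:
    def __init__(self, cfg: Config):
        # Always resolve through the hook; subclasses override only the hook.
        self.quant_config = self.get_quant_config(cfg)

    def get_quant_config(self, cfg: Config) -> Optional[str]:
        return cfg.quant_config                     # default: the verifier's config


class DrafterLayer(BaseLayer):
    def get_quant_config(self, cfg: Config) -> Optional[str]:
        return cfg.drafter_quant_config             # drafter resolves its own config


cfg = Config(verifier_quant="w4a16", drafter_quant=None)
assert BaseLayer(cfg).quant_config == "w4a16"
assert DrafterLayer(cfg).quant_config is None       # drafter no longer inherits w4a16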


@@ -13,6 +13,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -33,7 +35,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         super().__init__(vllm_config, prefix=prefix, config=config)

         config = config or vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
+        quant_config = self.get_quant_config(vllm_config)

         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
@@ -53,6 +55,16 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         else:
             self._residual_norm = self._norm_after_residual

+    def get_quant_config(
+            self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
+        """Use drafter's quantization config instead of verifier's."""
+        draft_model_config = vllm_config.speculative_config.draft_model_config
+        draft_load_config = vllm_config.load_config
+
+        return VllmConfig.get_quantization_config(
+            draft_model_config,
+            draft_load_config) if draft_model_config else None
+
     def _norm_before_residual(
             self,
             hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
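The override resolves quantization settings from the drafter's own model config rather than reusing the verifier's, so an Eagle3 drafter packaged with (or without) its own quantization no longer inherits the verifier's scheme; the w4a16-verifier test case added above is exactly that situation. A hedged standalone sketch of the same resolution, using only names visible in this diff plus the standard vllm.config import (the internals of VllmConfig.get_quantization_config are assumed, not shown, and the None guard on speculative_config is an addition for the sketch):

from typing import Optional

from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


def resolve_drafter_quant_config(
        vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
    """Resolve the drafter's quant config, falling back to None if absent."""
    spec = vllm_config.speculative_config
    draft_model_config = spec.draft_model_config if spec else None
    if draft_model_config is None:
        return None
    # Build the config from the drafter checkpoint's own metadata rather than
    # reusing the verifier's vllm_config.quant_config.
    return VllmConfig.get_quantization_config(draft_model_config,
                                              vllm_config.load_config)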