[Bugfix][Speculative Decoding] Fix Eagle3 quantization config issue (#25883)
Signed-off-by: Rahul Tuli <rtuli@redhat.com>
parent: d0d138bc55
commit: 145ac73317
@@ -14,6 +14,9 @@ from vllm.model_executor.models.interfaces import supports_eagle3
     pytest.param(
         "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
         id="qwen3-eagle3-speculator"),
+    pytest.param(
+        "nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized-w4a16",
+        id="qwen3-eagle3-speculator-w4a16-verifier"),
 ])
 def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path,
                                   monkeypatch):
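For orientation, here is a hedged, self-contained sketch (hypothetical test name and model strings, not the vLLM test body) of how pytest.param ids such as the new qwen3-eagle3-speculator-w4a16-verifier entry name individual cases of a parametrized test:

# Hypothetical stand-in test: it only illustrates how pytest.param ids select
# cases; the real test loads model_path with vllm_runner and generates tokens.
import pytest

@pytest.mark.parametrize("model_path", [
    pytest.param("quantized-speculator", id="qwen3-eagle3-speculator"),
    pytest.param("quantized-speculator-w4a16",
                 id="qwen3-eagle3-speculator-w4a16-verifier"),
])
def test_model_path_is_set(model_path):
    # Stand-in assertion; the parameter is forwarded per case.
    assert model_path

A single case can then be selected by id, for example:
pytest -k "qwen3-eagle3-speculator-w4a16-verifier"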
@@ -248,7 +248,7 @@ class LlamaDecoderLayer(nn.Module):
 
         config = config or vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
+        quant_config = self.get_quant_config(vllm_config)
 
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -328,6 +328,11 @@ class LlamaDecoderLayer(nn.Module):
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
 
+    def get_quant_config(
+            self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
+        """Get quantization config for this layer. Override in subclasses."""
+        return vllm_config.quant_config
+
 
 @support_torch_compile
 class LlamaModel(nn.Module):
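The change above replaces a direct read of vllm_config.quant_config with an overridable hook. A minimal sketch of that pattern, using hypothetical stand-in classes rather than vLLM's real VllmConfig and QuantizationConfig:

# Hypothetical stand-ins: the base layer resolves its quantization config
# through get_quant_config(), which subclasses may override.
from dataclasses import dataclass
from typing import Optional


@dataclass
class QuantConfig:               # stand-in for QuantizationConfig
    scheme: str


@dataclass
class Config:                    # stand-in for VllmConfig
    quant_config: Optional[QuantConfig] = None
    drafter_quant_config: Optional[QuantConfig] = None


class BaseLayer:
    def __init__(self, cfg: Config):
        # The layer no longer reads cfg.quant_config directly; it goes
        # through the overridable hook.
        self.quant_config = self.get_quant_config(cfg)

    def get_quant_config(self, cfg: Config) -> Optional[QuantConfig]:
        """Default: use the verifier's quantization config."""
        return cfg.quant_config


class DrafterLayer(BaseLayer):
    def get_quant_config(self, cfg: Config) -> Optional[QuantConfig]:
        """Eagle3 drafter: use the drafter's own quantization config."""
        return cfg.drafter_quant_config


cfg = Config(quant_config=QuantConfig("fp8"),
             drafter_quant_config=QuantConfig("w4a16"))
assert BaseLayer(cfg).quant_config.scheme == "fp8"
assert DrafterLayer(cfg).quant_config.scheme == "w4a16"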
@@ -13,6 +13,8 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -33,7 +35,7 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         super().__init__(vllm_config, prefix=prefix, config=config)
 
         config = config or vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
+        quant_config = self.get_quant_config(vllm_config)
 
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
@@ -53,6 +55,16 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         else:
             self._residual_norm = self._norm_after_residual
 
+    def get_quant_config(
+            self, vllm_config: VllmConfig) -> Optional[QuantizationConfig]:
+        """Use drafter's quantization config instead of verifier's."""
+        draft_model_config = vllm_config.speculative_config.draft_model_config
+        draft_load_config = vllm_config.load_config
+
+        return VllmConfig.get_quantization_config(
+            draft_model_config,
+            draft_load_config) if draft_model_config else None
+
     def _norm_before_residual(
             self,
             hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
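Per its docstring, the Eagle3 override resolves quantization from the draft model's own config and falls back to None when no draft model config is present, so the verifier's quantization config is no longer applied to the drafter layers. A hedged sketch of that guard, with hypothetical stand-in config classes:

# Hypothetical stand-ins; vLLM's actual objects are VllmConfig,
# SpeculativeConfig and ModelConfig.
from dataclasses import dataclass
from typing import Optional


@dataclass
class DraftModelConfig:
    quantization: Optional[str]  # e.g. "compressed-tensors", or None


@dataclass
class SpecConfig:
    draft_model_config: Optional[DraftModelConfig]


def drafter_quantization(spec: SpecConfig) -> Optional[str]:
    """Mirror of the override's guard: use the drafter's own setting,
    or None when there is no draft model config."""
    draft = spec.draft_model_config
    return draft.quantization if draft else None


# A quantized verifier no longer leaks its scheme into the drafter:
print(drafter_quantization(SpecConfig(DraftModelConfig(None))))  # -> None
print(drafter_quantization(SpecConfig(None)))                    # -> None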