diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py
index 8f4ba88677734..3617294bd621d 100644
--- a/vllm/model_executor/models/llama_eagle.py
+++ b/vllm/model_executor/models/llama_eagle.py
@@ -12,6 +12,7 @@
 from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
@@ -37,6 +38,17 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
             del self.input_layernorm
             self.input_layernorm = nn.Identity()
 
+    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
+        """Use drafter's quantization config instead of verifier's."""
+        draft_model_config = vllm_config.speculative_config.draft_model_config
+        draft_load_config = vllm_config.load_config
+
+        return (
+            VllmConfig.get_quantization_config(draft_model_config, draft_load_config)
+            if draft_model_config
+            else None
+        )
+
 
 @support_torch_compile
 class LlamaModel(nn.Module):
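
For context, here is a minimal, self-contained sketch of what the new override changes: with `get_quant_config`, the EAGLE drafter layers resolve their quantization config from `speculative_config.draft_model_config` rather than from the verifier (target) model's config. The classes below are simplified stand-ins, not vllm's real `VllmConfig`/`ModelConfig`, and the `get_quantization_config` stub just echoes a recorded quantization method; only the attribute names taken from the diff (`speculative_config.draft_model_config`, `load_config`, `get_quantization_config`) come from the change itself.

```python
from dataclasses import dataclass
from typing import Optional


# Hypothetical, simplified stand-ins for vllm's config objects, used only to
# illustrate the resolution order introduced by get_quant_config() above.
@dataclass
class ModelConfig:
    quantization: Optional[str] = None  # e.g. "fp8", or None for unquantized


@dataclass
class SpeculativeConfig:
    draft_model_config: Optional[ModelConfig] = None


@dataclass
class VllmConfigSketch:
    model_config: ModelConfig              # verifier (target) model
    speculative_config: SpeculativeConfig  # carries the drafter's own config
    load_config: object = None

    @staticmethod
    def get_quantization_config(model_config, load_config):
        # Placeholder for vllm's real resolution logic: here we simply return
        # whatever quantization method the given model config records.
        return model_config.quantization


def resolve_drafter_quant_config(vllm_config: VllmConfigSketch):
    """Mirrors the new override: prefer the drafter's config, else None."""
    draft_model_config = vllm_config.speculative_config.draft_model_config
    draft_load_config = vllm_config.load_config
    return (
        VllmConfigSketch.get_quantization_config(draft_model_config, draft_load_config)
        if draft_model_config
        else None
    )


if __name__ == "__main__":
    cfg = VllmConfigSketch(
        model_config=ModelConfig(quantization="fp8"),           # verifier is fp8
        speculative_config=SpeculativeConfig(ModelConfig(None)),  # drafter unquantized
    )
    # Without the override, drafter layers would inherit the verifier's "fp8"
    # setting; with it, they resolve the drafter's own setting instead (None).
    print(resolve_drafter_quant_config(cfg))  # -> None
```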