From 8c546102658f97b10d13bcf25193b65edc6ea6ff Mon Sep 17 00:00:00 2001
From: vllmellm
Date: Tue, 16 Sep 2025 12:45:38 +0800
Subject: [PATCH] [Bug] [Spec Dec]: Fix kv_cache dtype mismatch for Eagle3 drafter on FP8 target (#24505)

Signed-off-by: vllmellm
---
 vllm/model_executor/models/llama_eagle3.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index bceb6cc42768..99b77729b501 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -9,7 +9,7 @@
 import torch.nn as nn
 from transformers import LlamaConfig
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import VllmConfig
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -33,10 +33,14 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
     def __init__(
         self,
         config: LlamaConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
-        super().__init__(config, quant_config=quant_config, prefix=prefix)
+        super().__init__(config,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
@@ -114,6 +118,8 @@ class LlamaModel(nn.Module):
             speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
 
+        current_vllm_config = get_current_vllm_config()
+
         self.embed_tokens = VocabParallelEmbedding(
             self.config.vocab_size,
             self.config.hidden_size,
@@ -123,6 +129,7 @@ class LlamaModel(nn.Module):
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
                 config=self.config,
+                cache_config=current_vllm_config.cache_config,
                 prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
             )
         ])
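
Note on the fix (not part of the patch): the Eagle3 drafter's LlamaDecoderLayer previously dropped the engine's CacheConfig, so its attention layers fell back to the default kv-cache dtype while the FP8 target model allocated an FP8 kv-cache. The patch threads the engine-level CacheConfig (via get_current_vllm_config) down to the drafter's decoder layers so both use the same dtype. Below is a minimal, self-contained Python sketch of that propagation pattern; CacheConfig, Attention, and DecoderLayer here are simplified stand-ins, not vLLM's actual classes.

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class CacheConfig:
        # Simplified stand-in: only the field relevant to this patch,
        # the kv-cache dtype (e.g. "auto" or "fp8").
        cache_dtype: str = "auto"


    class Attention:
        def __init__(self, cache_config: Optional[CacheConfig] = None) -> None:
            # Without a cache_config, the layer falls back to the default dtype,
            # which is the source of the drafter/target mismatch.
            self.kv_cache_dtype = (cache_config or CacheConfig()).cache_dtype


    class DecoderLayer:
        def __init__(self, cache_config: Optional[CacheConfig] = None) -> None:
            # The fix: forward the engine-level cache_config down to attention
            # instead of silently dropping it.
            self.self_attn = Attention(cache_config=cache_config)


    if __name__ == "__main__":
        engine_cache_config = CacheConfig(cache_dtype="fp8")  # FP8 target

        buggy_layer = DecoderLayer()  # cache_config not forwarded
        fixed_layer = DecoderLayer(cache_config=engine_cache_config)

        print(buggy_layer.self_attn.kv_cache_dtype)  # "auto" (mismatch)
        print(fixed_layer.self_attn.kv_cache_dtype)  # "fp8"  (matches target)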