diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py
index bceb6cc42768..99b77729b501 100644
--- a/vllm/model_executor/models/llama_eagle3.py
+++ b/vllm/model_executor/models/llama_eagle3.py
@@ -9,7 +9,7 @@ import torch.nn as nn
 from transformers import LlamaConfig
 
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import VllmConfig
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -33,10 +33,14 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
     def __init__(
         self,
         config: LlamaConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
-        super().__init__(config, quant_config=quant_config, prefix=prefix)
+        super().__init__(config,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
@@ -114,6 +118,8 @@ class LlamaModel(nn.Module):
             speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
 
+        current_vllm_config = get_current_vllm_config()
+
         self.embed_tokens = VocabParallelEmbedding(
             self.config.vocab_size,
             self.config.hidden_size,
@@ -123,6 +129,7 @@ class LlamaModel(nn.Module):
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
                 config=self.config,
+                cache_config=current_vllm_config.cache_config,
                 prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
             )
         ])
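
Note: the diff forwards cache_config into the Eagle3 draft LlamaDecoderLayer, and LlamaModel obtains it from the ambient VllmConfig via get_current_vllm_config() rather than threading it through its own constructor. Below is a minimal, hypothetical sketch of that pattern; DraftDecoderLayer and build_draft_layer are illustrative names only and are not part of this diff.

from typing import Optional

from vllm.config import CacheConfig, get_current_vllm_config


class DraftDecoderLayer:
    # Hypothetical stand-in for the Eagle3 draft decoder layer.
    def __init__(self, cache_config: Optional[CacheConfig] = None) -> None:
        # The cache config carries KV-cache settings the attention layer needs.
        self.cache_config = cache_config


def build_draft_layer() -> DraftDecoderLayer:
    # Inside the model-construction context that vLLM sets up,
    # get_current_vllm_config() returns the active VllmConfig, so the
    # cache config does not have to be passed through every caller.
    current_vllm_config = get_current_vllm_config()
    return DraftDecoderLayer(cache_config=current_vllm_config.cache_config)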