diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 93ab499e64a2..7d8c950cb2a9 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -85,6 +85,7 @@ class LlamaAttention(nn.Module):
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
+        rope_theta: float = 10000,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -99,6 +100,7 @@ class LlamaAttention(nn.Module):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
 
         self.qkv_proj = ColumnParallelLinear(
             hidden_size,
@@ -118,6 +120,7 @@ class LlamaAttention(nn.Module):
         self.attn = PagedAttentionWithRoPE(self.num_heads,
                                            self.head_dim,
                                            self.scaling,
+                                           base=self.rope_theta,
                                            rotary_dim=self.head_dim,
                                            num_kv_heads=self.num_kv_heads)
 
@@ -143,10 +146,13 @@ class LlamaDecoderLayer(nn.Module):
     def __init__(self, config: LlamaConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
        )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
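
For reference, a minimal sketch (not part of the patch) of how the getattr fallback in LlamaDecoderLayer.__init__ behaves. The SimpleNamespace objects are hypothetical stand-ins for a transformers LlamaConfig: a config without a rope_theta field falls back to the original LLaMA default of 10000, while a config that carries the field (e.g. Code Llama checkpoints, which use a base of 1e6) passes its value through to PagedAttentionWithRoPE.

    from types import SimpleNamespace

    # Hypothetical stand-in configs; the real code receives a transformers LlamaConfig.
    old_config = SimpleNamespace(num_attention_heads=32)                        # no rope_theta field
    new_config = SimpleNamespace(num_attention_heads=32, rope_theta=1000000.0)  # field present

    # Same fallback pattern as the patch: use the config value when present,
    # otherwise keep the original LLaMA default.
    print(getattr(old_config, "rope_theta", 10000))   # -> 10000
    print(getattr(new_config, "rope_theta", 10000))   # -> 1000000.0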