From ebede26ebf9c142465e1bfa8930fc8d2cbf5d953 Mon Sep 17 00:00:00 2001
From: Jie Li
Date: Fri, 8 Dec 2023 00:32:08 +0800
Subject: [PATCH] Make InternLM follow `rope_scaling` in `config.json` (#1956)

Co-authored-by: lijie8
---
 vllm/model_executor/models/internlm.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py
index ebb96c75736cd..ba28ff8d140af 100644
--- a/vllm/model_executor/models/internlm.py
+++ b/vllm/model_executor/models/internlm.py
@@ -1,5 +1,5 @@
 # -*- coding: utf-8 -*-
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -67,6 +67,7 @@ class InternLMAttention(nn.Module):
         rope_theta: float = 10000,
         max_position_embeddings: int = 8192,
         linear_method: Optional[LinearMethodBase] = None,
+        rope_scaling: Optional[Dict[str, Any]] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -99,6 +100,7 @@ class InternLMAttention(nn.Module):
             rotary_dim=self.head_dim,
             max_position=self.max_position_embeddings,
             base=self.rope_theta,
+            rope_scaling=rope_scaling,
         )
         self.attn = PagedAttention(self.num_heads, self.head_dim,
                                    self.scaling)
@@ -139,6 +141,7 @@ class InternLMDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             max_position_embeddings=max_position_embeddings,
             linear_method=linear_method,
+            rope_scaling=getattr(config, "rope_scaling", None),
         )
         self.mlp = InternLMMLP(
             hidden_size=self.hidden_size,
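
Note (not part of the patch): a minimal sketch of the lookup the patch relies on, i.e. reading an
optional `rope_scaling` entry from a HF-style `config.json` with `getattr(config, "rope_scaling", None)`
and forwarding it when present. `DummyConfig` and the example value {"type": "dynamic", "factor": 2.0}
are assumptions for illustration only, not taken from the InternLM checkpoints.

from typing import Any, Dict, Optional


class DummyConfig:
    """Stand-in for an InternLM HF config loaded from config.json (hypothetical)."""

    def __init__(self, rope_scaling: Optional[Dict[str, Any]] = None):
        if rope_scaling is not None:
            # Only set the attribute when config.json actually carries the field,
            # mirroring configs that omit rope_scaling entirely.
            self.rope_scaling = rope_scaling


def read_rope_scaling(config) -> Optional[Dict[str, Any]]:
    # Same fallback the patch adds in InternLMDecoderLayer: a missing attribute
    # yields None, so models without rope_scaling keep plain rotary embeddings.
    return getattr(config, "rope_scaling", None)


print(read_rope_scaling(DummyConfig({"type": "dynamic", "factor": 2.0})))
# -> {'type': 'dynamic', 'factor': 2.0}
print(read_rope_scaling(DummyConfig()))
# -> None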