Make InternLM follow rope_scaling in config.json (#1956)

Co-authored-by: lijie8 <lijie8@sensetime.com>
This commit is contained in:
Jie Li 2023-12-08 00:32:08 +08:00 committed by GitHub
parent d940ce497e
commit ebede26ebf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import nn
@ -67,6 +67,7 @@ class InternLMAttention(nn.Module):
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
rope_scaling: Optional[Dict[str, Any]] = None,
):
super().__init__()
self.hidden_size = hidden_size
@ -99,6 +100,7 @@ class InternLMAttention(nn.Module):
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
@ -139,6 +141,7 @@ class InternLMDecoderLayer(nn.Module):
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
rope_scaling=getattr(config, "rope_scaling", None),
)
self.mlp = InternLMMLP(
hidden_size=self.hidden_size,