mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-04 07:11:20 +08:00
Make InternLM follow rope_scaling in config.json (#1956)
Co-authored-by: lijie8 <lijie8@sensetime.com>
This commit is contained in:
parent
d940ce497e
commit
ebede26ebf
@ -1,5 +1,5 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from typing import List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -67,6 +67,7 @@ class InternLMAttention(nn.Module):
|
|||||||
rope_theta: float = 10000,
|
rope_theta: float = 10000,
|
||||||
max_position_embeddings: int = 8192,
|
max_position_embeddings: int = 8192,
|
||||||
linear_method: Optional[LinearMethodBase] = None,
|
linear_method: Optional[LinearMethodBase] = None,
|
||||||
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@ -99,6 +100,7 @@ class InternLMAttention(nn.Module):
|
|||||||
rotary_dim=self.head_dim,
|
rotary_dim=self.head_dim,
|
||||||
max_position=self.max_position_embeddings,
|
max_position=self.max_position_embeddings,
|
||||||
base=self.rope_theta,
|
base=self.rope_theta,
|
||||||
|
rope_scaling=rope_scaling,
|
||||||
)
|
)
|
||||||
self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
|
self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
|
||||||
|
|
||||||
@ -139,6 +141,7 @@ class InternLMDecoderLayer(nn.Module):
|
|||||||
rope_theta=rope_theta,
|
rope_theta=rope_theta,
|
||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
linear_method=linear_method,
|
linear_method=linear_method,
|
||||||
|
rope_scaling=getattr(config, "rope_scaling", None),
|
||||||
)
|
)
|
||||||
self.mlp = InternLMMLP(
|
self.mlp = InternLMMLP(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user