mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:15:01 +08:00
Add rope_scaling to Qwen (#1210)
This commit is contained in:
parent
20f7cc4cde
commit
7bedab5748
@ -8,7 +8,7 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -76,13 +76,12 @@ class QWenMLP(nn.Module):
|
||||
|
||||
class QWenAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
max_position_embeddings: int,
|
||||
rope_theta: float = 10000,
|
||||
):
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
max_position_embeddings: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
|
||||
@ -116,7 +115,7 @@ class QWenAttention(nn.Module):
|
||||
rotary_dim=self.head_dim,
|
||||
base=rope_theta,
|
||||
max_position=max_position_embeddings,
|
||||
)
|
||||
rope_scaling=rope_scaling)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -144,10 +143,12 @@ class QWenBlock(nn.Module):
|
||||
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
self.attn = QWenAttention(config.hidden_size,
|
||||
config.num_attention_heads,
|
||||
config.max_position_embeddings,
|
||||
rope_theta=rope_theta)
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling)
|
||||
|
||||
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user