Support Longchat and RoPE scaling (#555)

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

commit 21877b0d75
parent cf5cb1e33e
@@ -351,6 +351,17 @@ def _get_and_verify_max_len(
         if max_len_key is not None:
             derived_max_model_len = min(derived_max_model_len, max_len_key)
 
+    rope_scaling = getattr(hf_config, "rope_scaling", None)
+    if rope_scaling is not None:
+        if derived_max_model_len == float("inf"):
+            raise ValueError(
+                "When using rope_scaling, the model's config.json must "
+                "contain one of the following keys to determine the original "
+                f"maximum length of the model: {possible_keys}")
+        assert "factor" in rope_scaling
+        scaling_factor = rope_scaling["factor"]
+        derived_max_model_len *= scaling_factor
+
     if max_model_len is None:
         max_model_len = derived_max_model_len
     elif max_model_len > derived_max_model_len:
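For orientation, the effect of the new rope_scaling branch can be reproduced in isolation: the derived maximum length becomes the checkpoint's original context length multiplied by the scaling factor. The snippet below is a minimal CPU-only sketch, not the vLLM function itself; the toy config values (2048-token base length, 4x linear scaling) are assumptions for illustration.

from types import SimpleNamespace

# Hypothetical HF-style config: a 2048-token base model with 4x RoPE scaling.
hf_config = SimpleNamespace(
    max_position_embeddings=2048,
    rope_scaling={"type": "linear", "factor": 4.0},
)

derived_max_model_len = float("inf")
possible_keys = ["max_position_embeddings"]  # subset of the keys the real function checks
for key in possible_keys:
    max_len_key = getattr(hf_config, key, None)
    if max_len_key is not None:
        derived_max_model_len = min(derived_max_model_len, max_len_key)

rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
    derived_max_model_len *= rope_scaling["factor"]

print(derived_max_model_len)  # 8192.0: the context window after scaling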
@@ -1,5 +1,5 @@
 """Multi-head attention."""
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 import torch.nn as nn
@@ -9,8 +9,10 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
 
 from vllm import attention_ops
 from vllm import cache_ops
-from vllm import pos_encoding_ops
 from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.rotary_embedding import (
+    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
+    RotaryEmbedding)
 
 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 
@@ -247,7 +249,7 @@ class PagedAttention(nn.Module):
 
 
 class PagedAttentionWithRoPE(PagedAttention):
-    """PagedAttention with rotary embedding."""
+    """PagedAttention with rotary positional embedding."""
 
     def __init__(
         self,
@@ -259,34 +261,26 @@ class PagedAttentionWithRoPE(PagedAttention):
         base: int = 10000,
         num_kv_heads: Optional[int] = None,
         is_neox_style: bool = True,
+        rope_scaling: Optional[Dict[str, Any]] = None,
     ) -> None:
         super().__init__(num_heads, head_size, scale, num_kv_heads)
-        self.is_neox_style = is_neox_style
-
-        # Create the cos and sin cache.
-        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
-        # However, we use `torch.arange(..., dtype=torch.float)` instead to
-        # avoid numerical issues with large base values (e.g., 10000000).
-        # This may cause a slight numerical difference between the HF
-        # implementation and ours.
-        # NOTE(woosuk): To exactly match the HF implementation, we need to
-        # use CPU to compute the cache and then move it to GPU. However, we
-        # create the cache on GPU for faster initialization. This may cause
-        # a slight numerical difference between the HF implementation and ours.
-        inv_freq = 1.0 / (base**(torch.arange(
-            0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
-        t = torch.arange(max_position, dtype=torch.float, device="cuda")
-        freqs = torch.einsum("i,j -> ij", t, inv_freq)
-        cos = freqs.cos()
-        sin = freqs.sin()
-        cache = torch.cat((cos, sin), dim=-1)
-
-        # FIXME(woosuk): This assumes that we configure the default dtype when
-        # initializing the model.
-        torch_dtype = torch.get_default_dtype()
-        cache = cache.to(torch_dtype)
-        # Embedding size: [max_position, rotary_dim]
-        self.register_buffer("cos_sin_cache", cache, persistent=False)
+        if rope_scaling is None:
+            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
+                                              max_position, base,
+                                              is_neox_style)
+        else:
+            scaling_type = rope_scaling["type"]
+            scaling_factor = rope_scaling["factor"]
+            if scaling_type == "linear":
+                self.rotary_emb = LinearScalingRotaryEmbedding(
+                    head_size, rotary_dim, max_position, base, is_neox_style,
+                    scaling_factor)
+            elif scaling_type == "dynamic":
+                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                    head_size, rotary_dim, max_position, base, is_neox_style,
+                    scaling_factor)
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
     def forward(
         self,
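The constructor now just selects a rotary-embedding implementation instead of building the cos/sin cache inline. A construction sketch is shown below; it assumes the class lives in vllm.model_executor.layers.attention, and the parameter names not visible in this hunk (num_heads, head_size, scale) are inferred from the super().__init__ call and the llama.py call site, so treat the exact signature as an assumption. Instantiation also needs a CUDA device, since the cos/sin cache is created on the GPU.

from vllm.model_executor.layers.attention import PagedAttentionWithRoPE

attn = PagedAttentionWithRoPE(
    num_heads=32,
    head_size=128,
    scale=128 ** -0.5,
    rotary_dim=128,
    max_position=2048,
    base=10000,
    num_kv_heads=32,
    is_neox_style=True,
    rope_scaling={"type": "linear", "factor": 4.0},  # or {"type": "dynamic", "factor": ...}
)
# rope_scaling=None keeps the plain RotaryEmbedding; an unrecognized "type" raises ValueError.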
@@ -303,7 +297,7 @@ class PagedAttentionWithRoPE(PagedAttention):
 
         Args:
             positions: shape = [num_tokens]
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
@@ -319,14 +313,7 @@ class PagedAttentionWithRoPE(PagedAttention):
 
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
-        pos_encoding_ops.rotary_embedding(
-            positions,
-            query,
-            key,
-            self.head_size,
-            self.cos_sin_cache,
-            self.is_neox_style,
-        )
+        query, key = self.rotary_emb(positions, query, key)
         return super().forward(
             query,
             key,
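The fused kernel call is now hidden behind self.rotary_emb, which rotates query and key in place and returns them. As a point of reference for what that rotation computes, here is a small CPU-only sketch of NeoX-style RoPE on flattened [num_tokens, num_heads * head_size] tensors. It is a simplified re-derivation under the assumption rotary_dim == head_size, not the CUDA kernel, and the helper name is made up for illustration.

import torch

def apply_rope_neox_reference(
    positions: torch.Tensor,      # [num_tokens]
    x: torch.Tensor,              # [num_tokens, num_heads * head_size]
    cos_sin_cache: torch.Tensor,  # [max_position, rotary_dim]: cos half, then sin half
    head_size: int,
) -> torch.Tensor:
    num_tokens = x.shape[0]
    x = x.view(num_tokens, -1, head_size)                  # [num_tokens, num_heads, head_size]
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)   # each [num_tokens, head_size // 2]
    cos = torch.cat([cos, cos], dim=-1).unsqueeze(1)       # broadcast over heads
    sin = torch.cat([sin, sin], dim=-1).unsqueeze(1)
    # NeoX style pairs dimension i with dimension i + head_size // 2 ("rotate half").
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat([-x2, x1], dim=-1)
    return (x * cos + rotated * sin).view(num_tokens, -1)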
vllm/model_executor/layers/rotary_embedding.py (new file, 169 lines)
@@ -0,0 +1,169 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Rotary Positional Embeddings."""
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from vllm import pos_encoding_ops
+
+
+class RotaryEmbedding(nn.Module):
+    """Original rotary positional embedding."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.is_neox_style = is_neox_style
+
+        cache = self._compute_cos_sin_cache()
+        cache = cache.to(torch.get_default_dtype())
+        self.register_buffer("cos_sin_cache", cache, persistent=False)
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        """Compute the inverse frequency."""
+        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
+        # However, we use `torch.arange(..., dtype=torch.float)` instead to
+        # avoid numerical issues with large base values (e.g., 10000000).
+        # This may cause a slight numerical difference between the HF
+        # implementation and ours.
+        # NOTE(woosuk): To exactly match the HF implementation, we need to
+        # use CPU to compute the cache and then move it to GPU. However, we
+        # create the cache on GPU for faster initialization. This may cause
+        # a slight numerical difference between the HF implementation and ours.
+        inv_freq = 1.0 / (base**(torch.arange(
+            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
+                                 self.rotary_dim))
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        """Compute the cos and sin cache."""
+        inv_freq = self._compute_inv_freq(self.base)
+        t = torch.arange(self.max_position_embeddings,
+                         dtype=torch.float,
+                         device="cuda")
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # pos_encoding_ops.rotary_embedding() is an in-place operation that
+        # updates the query and key tensors.
+        pos_encoding_ops.rotary_embedding(positions, query, key,
+                                          self.head_size, self.cos_sin_cache,
+                                          self.is_neox_style)
+        return query, key
+
+
+class LinearScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with linear scaling.
+
+    Credits to the Reddit user /u/kaiokendev
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.base)
+        # NOTE(woosuk): self.max_position_embeddings is the original
+        # maximum length before applying the rope scaling.
+        # Thus, the maximum length after applying the rope scaling is
+        # self.max_position_embeddings * self.scaling_factor.
+        max_len = self.max_position_embeddings * self.scaling_factor
+        t = torch.arange(max_len, dtype=torch.float, device="cuda")
+        t = t / self.scaling_factor
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
+class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with Dynamic NTK scaling.
+
+    Credits to the Reddit users /u/bloc97 and /u/emozilla
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        # NOTE(woosuk): self.max_position_embeddings is the original
+        # maximum length before applying the rope scaling.
+        # Thus, the maximum length after applying the rope scaling is
+        # self.max_position_embeddings * self.scaling_factor.
+        max_len = self.max_position_embeddings * self.scaling_factor
+        base = self.base * (
+            (self.scaling_factor * max_len / self.max_position_embeddings) -
+            (self.scaling_factor - 1))**(self.rotary_dim /
+                                         (self.rotary_dim - 2))
+        inv_freq = self._compute_inv_freq(base)
+        t = torch.arange(max_len, dtype=torch.float, device="cuda")
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
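Both subclasses only override _compute_cos_sin_cache. For intuition, the difference can be re-derived on CPU (device and dtype differ from the code above, so this is an illustrative sketch rather than a drop-in equivalent): linear scaling keeps the base and divides every position index by the factor, while dynamic NTK keeps the integer positions and enlarges the base so the lowest frequencies span the longer window. The concrete numbers (rotary_dim=128, base=10000, 2048 original positions, factor 4) are assumptions.

import torch

def inv_freq(base: float, rotary_dim: int) -> torch.Tensor:
    # Same formula as RotaryEmbedding._compute_inv_freq, computed on CPU.
    return 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))

rotary_dim, base, max_pos, factor = 128, 10000.0, 2048, 4.0
max_len = int(max_pos * factor)

# Linear scaling: position 8191 is embedded like fractional position ~2047.75.
t_linear = torch.arange(max_len, dtype=torch.float) / factor
freqs_linear = torch.einsum("i,j->ij", t_linear, inv_freq(base, rotary_dim))

# Dynamic NTK scaling: positions stay integral, but the base is stretched
# using the same formula as DynamicNTKScalingRotaryEmbedding above.
ntk_base = base * ((factor * max_len / max_pos) - (factor - 1)) ** (rotary_dim / (rotary_dim - 2))
t_dynamic = torch.arange(max_len, dtype=torch.float)
freqs_dynamic = torch.einsum("i,j->ij", t_dynamic, inv_freq(ntk_base, rotary_dim))

print(freqs_linear.shape, freqs_dynamic.shape)  # both torch.Size([8192, 64])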
@@ -25,7 +25,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -92,6 +92,7 @@ class LlamaAttention(nn.Module):
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -135,7 +136,8 @@ class LlamaAttention(nn.Module):
             base=self.rope_theta,
             max_position=self.max_position_embeddings,
             rotary_dim=self.head_dim,
-            num_kv_heads=self.num_kv_heads)
+            num_kv_heads=self.num_kv_heads,
+            rope_scaling=rope_scaling)
 
     def forward(
         self,
@@ -165,6 +167,7 @@ class LlamaDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
         self.self_attn = LlamaAttention(
@@ -172,6 +175,7 @@ class LlamaDecoderLayer(nn.Module):
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
         )
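Because the Llama layer now reads rope_scaling straight off the HF config, enabling it is purely a matter of the checkpoint's config.json carrying a rope_scaling entry. A quick way to see what a given checkpoint provides is sketched below; the checkpoint name is an example (LongChat is the family named in the commit title) and the printed value depends on the checkpoint, so treat both as assumptions.

from transformers import AutoConfig

config = AutoConfig.from_pretrained("lmsys/longchat-13b-16k")

# Mirrors the getattr calls added in LlamaDecoderLayer.__init__.
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)

print(rope_scaling)  # e.g. {"type": "linear", "factor": 8.0} for a 2048 -> 16384 LongChat model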