Support Longchat and RoPE scaling (#555)

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Lily Liu 2023-09-27 03:36:02 -07:00 committed by GitHub
parent cf5cb1e33e
commit 21877b0d75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 211 additions and 40 deletions

View File

@ -351,6 +351,17 @@ def _get_and_verify_max_len(
if max_len_key is not None:
derived_max_model_len = min(derived_max_model_len, max_len_key)
rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
if derived_max_model_len == float("inf"):
raise ValueError(
"When using rope_scaling, the model's config.json must "
"contain one of the following keys to determine the original "
f"maximum length of the model: {possible_keys}")
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
derived_max_model_len *= scaling_factor
if max_model_len is None:
max_model_len = derived_max_model_len
elif max_model_len > derived_max_model_len:

View File

@ -1,5 +1,5 @@
"""Multi-head attention."""
from typing import List, Optional
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
@ -9,8 +9,10 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
from vllm import attention_ops
from vllm import cache_ops
from vllm import pos_encoding_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.rotary_embedding import (
DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
RotaryEmbedding)
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
@ -247,7 +249,7 @@ class PagedAttention(nn.Module):
class PagedAttentionWithRoPE(PagedAttention):
"""PagedAttention with rotary embedding."""
"""PagedAttention with rotary positional embedding."""
def __init__(
self,
@ -259,34 +261,26 @@ class PagedAttentionWithRoPE(PagedAttention):
base: int = 10000,
num_kv_heads: Optional[int] = None,
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads)
self.is_neox_style = is_neox_style
# Create the cos and sin cache.
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
# avoid numerical issues with large base values (e.g., 10000000).
# This may cause a slight numerical difference between the HF
# implementation and ours.
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
t = torch.arange(max_position, dtype=torch.float, device="cuda")
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
# FIXME(woosuk): This assumes that we configure the default dtype when
# initializing the model.
torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype)
# Embedding size: [max_position, rotary_dim]
self.register_buffer("cos_sin_cache", cache, persistent=False)
if rope_scaling is None:
self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style)
else:
scaling_type = rope_scaling["type"]
scaling_factor = rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LinearScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor)
elif scaling_type == "dynamic":
self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def forward(
self,
@ -303,7 +297,7 @@ class PagedAttentionWithRoPE(PagedAttention):
Args:
positions: shape = [num_tokens]
query: shape = [num_tokens, num_heads * head_size]
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
@ -319,14 +313,7 @@ class PagedAttentionWithRoPE(PagedAttention):
# Apply rotary embedding to the query and key before passing them
# to the attention op.
pos_encoding_ops.rotary_embedding(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
query, key = self.rotary_emb(positions, query, key)
return super().forward(
query,
key,

View File

@ -0,0 +1,169 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
from typing import Tuple, Union
import torch
import torch.nn as nn
from vllm import pos_encoding_ops
class RotaryEmbedding(nn.Module):
"""Original rotary positional embedding."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
cache = self._compute_cos_sin_cache()
cache = cache.to(torch.get_default_dtype())
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
# avoid numerical issues with large base values (e.g., 10000000).
# This may cause a slight numerical difference between the HF
# implementation and ours.
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings,
dtype=torch.float,
device="cuda")
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# pos_encoding_ops.rotary_embedding() is an in-place operation that
# updates the query and key tensors.
pos_encoding_ops.rotary_embedding(positions, query, key,
self.head_size, self.cos_sin_cache,
self.is_neox_style)
return query, key
class LinearScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with linear scaling.
Credits to the Reddit user /u/kaiokendev
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
) -> None:
self.scaling_factor = scaling_factor
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style)
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * self.scaling_factor
t = torch.arange(max_len, dtype=torch.float, device="cuda")
t = t / self.scaling_factor
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling.
Credits to the Reddit users /u/bloc97 and /u/emozilla
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
) -> None:
self.scaling_factor = scaling_factor
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style)
def _compute_cos_sin_cache(self) -> torch.Tensor:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * self.scaling_factor
base = self.base * (
(self.scaling_factor * max_len / self.max_position_embeddings) -
(self.scaling_factor - 1))**(self.rotary_dim /
(self.rotary_dim - 2))
inv_freq = self._compute_inv_freq(base)
t = torch.arange(max_len, dtype=torch.float, device="cuda")
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache

View File

@ -25,7 +25,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import nn
@ -92,6 +92,7 @@ class LlamaAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
@ -135,7 +136,8 @@ class LlamaAttention(nn.Module):
base=self.rope_theta,
max_position=self.max_position_embeddings,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
rope_scaling=rope_scaling)
def forward(
self,
@ -165,6 +167,7 @@ class LlamaDecoderLayer(nn.Module):
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = LlamaAttention(
@ -172,6 +175,7 @@ class LlamaDecoderLayer(nn.Module):
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
)