diff --git a/vllm/config.py b/vllm/config.py
index 0e3f4ac87d6c..328ba94ca4a2 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -351,6 +351,17 @@ def _get_and_verify_max_len(
         if max_len_key is not None:
             derived_max_model_len = min(derived_max_model_len, max_len_key)
 
+    rope_scaling = getattr(hf_config, "rope_scaling", None)
+    if rope_scaling is not None:
+        if derived_max_model_len == float("inf"):
+            raise ValueError(
+                "When using rope_scaling, the model's config.json must "
+                "contain one of the following keys to determine the original "
+                f"maximum length of the model: {possible_keys}")
+        assert "factor" in rope_scaling
+        scaling_factor = rope_scaling["factor"]
+        derived_max_model_len *= scaling_factor
+
     if max_model_len is None:
         max_model_len = derived_max_model_len
     elif max_model_len > derived_max_model_len:
diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py
index a60f7b7bae24..ccae52120b69 100644
--- a/vllm/model_executor/layers/attention.py
+++ b/vllm/model_executor/layers/attention.py
@@ -1,5 +1,5 @@
 """Multi-head attention."""
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 import torch.nn as nn
@@ -9,8 +9,10 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                          LowerTriangularMaskWithTensorBias)
 
 from vllm import attention_ops
 from vllm import cache_ops
-from vllm import pos_encoding_ops
 from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.rotary_embedding import (
+    DynamicNTKScalingRotaryEmbedding, LinearScalingRotaryEmbedding,
+    RotaryEmbedding)
 
 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
@@ -247,7 +249,7 @@ class PagedAttention(nn.Module):
 
 
 class PagedAttentionWithRoPE(PagedAttention):
-    """PagedAttention with rotary embedding."""
+    """PagedAttention with rotary positional embedding."""
 
     def __init__(
         self,
@@ -259,34 +261,26 @@ class PagedAttentionWithRoPE(PagedAttention):
         base: int = 10000,
         num_kv_heads: Optional[int] = None,
         is_neox_style: bool = True,
+        rope_scaling: Optional[Dict[str, Any]] = None,
     ) -> None:
         super().__init__(num_heads, head_size, scale, num_kv_heads)
-        self.is_neox_style = is_neox_style
-
-        # Create the cos and sin cache.
-        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
-        # However, we use `torch.arange(..., dtype=torch.float)` instead to
-        # avoid numerical issues with large base values (e.g., 10000000).
-        # This may cause a slight numerical difference between the HF
-        # implementation and ours.
-        # NOTE(woosuk): To exactly match the HF implementation, we need to
-        # use CPU to compute the cache and then move it to GPU. However, we
-        # create the cache on GPU for faster initialization. This may cause
-        # a slight numerical difference between the HF implementation and ours.
-        inv_freq = 1.0 / (base**(torch.arange(
-            0, rotary_dim, 2, dtype=torch.float, device="cuda") / rotary_dim))
-        t = torch.arange(max_position, dtype=torch.float, device="cuda")
-        freqs = torch.einsum("i,j -> ij", t, inv_freq)
-        cos = freqs.cos()
-        sin = freqs.sin()
-        cache = torch.cat((cos, sin), dim=-1)
-
-        # FIXME(woosuk): This assumes that we configure the default dtype when
-        # initializing the model.
-        torch_dtype = torch.get_default_dtype()
-        cache = cache.to(torch_dtype)
-        # Embedding size: [max_position, rotary_dim]
-        self.register_buffer("cos_sin_cache", cache, persistent=False)
+        if rope_scaling is None:
+            self.rotary_emb = RotaryEmbedding(head_size, rotary_dim,
+                                              max_position, base,
+                                              is_neox_style)
+        else:
+            scaling_type = rope_scaling["type"]
+            scaling_factor = rope_scaling["factor"]
+            if scaling_type == "linear":
+                self.rotary_emb = LinearScalingRotaryEmbedding(
+                    head_size, rotary_dim, max_position, base, is_neox_style,
+                    scaling_factor)
+            elif scaling_type == "dynamic":
+                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                    head_size, rotary_dim, max_position, base, is_neox_style,
+                    scaling_factor)
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
     def forward(
         self,
@@ -303,7 +297,7 @@ class PagedAttentionWithRoPE(PagedAttention):
 
         Args:
             positions: shape = [num_tokens]
-            query: shape = [num_tokens, num_heads * head_size]
+            query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
             key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
@@ -319,14 +313,7 @@ class PagedAttentionWithRoPE(PagedAttention):
 
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
-        pos_encoding_ops.rotary_embedding(
-            positions,
-            query,
-            key,
-            self.head_size,
-            self.cos_sin_cache,
-            self.is_neox_style,
-        )
+        query, key = self.rotary_emb(positions, query, key)
         return super().forward(
             query,
             key,
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
new file mode 100644
index 000000000000..4ecde07562fa
--- /dev/null
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -0,0 +1,169 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Rotary Positional Embeddings."""
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from vllm import pos_encoding_ops
+
+
+class RotaryEmbedding(nn.Module):
+    """Original rotary positional embedding."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+    ) -> None:
+        super().__init__()
+        self.head_size = head_size
+        self.rotary_dim = rotary_dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.is_neox_style = is_neox_style
+
+        cache = self._compute_cos_sin_cache()
+        cache = cache.to(torch.get_default_dtype())
+        self.register_buffer("cos_sin_cache", cache, persistent=False)
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        """Compute the inverse frequency."""
+        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
+        # However, we use `torch.arange(..., dtype=torch.float)` instead to
+        # avoid numerical issues with large base values (e.g., 10000000).
+        # This may cause a slight numerical difference between the HF
+        # implementation and ours.
+        # NOTE(woosuk): To exactly match the HF implementation, we need to
+        # use CPU to compute the cache and then move it to GPU. However, we
+        # create the cache on GPU for faster initialization. This may cause
+        # a slight numerical difference between the HF implementation and ours.
+        inv_freq = 1.0 / (base**(torch.arange(
+            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
+                                 self.rotary_dim))
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        """Compute the cos and sin cache."""
+        inv_freq = self._compute_inv_freq(self.base)
+        t = torch.arange(self.max_position_embeddings,
+                         dtype=torch.float,
+                         device="cuda")
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # pos_encoding_ops.rotary_embedding() is an in-place operation that
+        # updates the query and key tensors.
+        pos_encoding_ops.rotary_embedding(positions, query, key,
+                                          self.head_size, self.cos_sin_cache,
+                                          self.is_neox_style)
+        return query, key
+
+
+class LinearScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with linear scaling.
+
+    Credits to the Reddit user /u/kaiokendev
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.base)
+        # NOTE(woosuk): self.max_position_embeddings is the original
+        # maximum length before applying the rope scaling.
+        # Thus, the maximum length after applying the rope scaling is
+        # self.max_position_embeddings * self.scaling_factor.
+        max_len = self.max_position_embeddings * self.scaling_factor
+        t = torch.arange(max_len, dtype=torch.float, device="cuda")
+        t = t / self.scaling_factor
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
+class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with Dynamic NTK scaling.
+
+    Credits to the Reddit users /u/bloc97 and /u/emozilla
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style)
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        # NOTE(woosuk): self.max_position_embeddings is the original
+        # maximum length before applying the rope scaling.
+        # Thus, the maximum length after applying the rope scaling is
+        # self.max_position_embeddings * self.scaling_factor.
+        max_len = self.max_position_embeddings * self.scaling_factor
+        base = self.base * (
+            (self.scaling_factor * max_len / self.max_position_embeddings) -
+            (self.scaling_factor - 1))**(self.rotary_dim /
+                                         (self.rotary_dim - 2))
+        inv_freq = self._compute_inv_freq(base)
+        t = torch.arange(max_len, dtype=torch.float, device="cuda")
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 79b7bed26727..d3b5f7283e6f 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -25,7 +25,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -92,6 +92,7 @@ class LlamaAttention(nn.Module):
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -135,7 +136,8 @@ class LlamaAttention(nn.Module):
                                            base=self.rope_theta,
                                            max_position=self.max_position_embeddings,
                                            rotary_dim=self.head_dim,
-                                           num_kv_heads=self.num_kv_heads)
+                                           num_kv_heads=self.num_kv_heads,
+                                           rope_scaling=rope_scaling)
 
     def forward(
         self,
@@ -165,6 +167,7 @@ class LlamaDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
         self.self_attn = LlamaAttention(
@@ -172,6 +175,7 @@ class LlamaDecoderLayer(nn.Module):
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
         )
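
Note (reviewer sketch, not part of the patch): the vllm/config.py hunk above derives the serving
context length by multiplying the model's original maximum length by the rope_scaling factor from
config.json. A minimal standalone sketch of that arithmetic, using a plain dict and an abbreviated,
hypothetical stand-in for vLLM's possible_keys list:

def derive_max_model_len(hf_config: dict) -> float:
    # Abbreviated stand-in for the possible_keys list in vllm/config.py.
    possible_keys = ["max_position_embeddings", "n_positions", "max_seq_len"]
    derived_max_model_len = float("inf")
    for key in possible_keys:
        if key in hf_config:
            derived_max_model_len = min(derived_max_model_len, hf_config[key])
    rope_scaling = hf_config.get("rope_scaling")
    if rope_scaling is not None:
        # Mirrors the patch: scale the original maximum length by the factor.
        derived_max_model_len *= rope_scaling["factor"]
    return derived_max_model_len

# A 4096-position base model with {"type": "linear", "factor": 2.0} is served
# with an 8192-token context window.
print(derive_max_model_len({
    "max_position_embeddings": 4096,
    "rope_scaling": {"type": "linear", "factor": 2.0},
}))  # 8192.0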
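
Note (reviewer sketch, not part of the patch): DynamicNTKScalingRotaryEmbedding above computes the
adjusted RoPE base once, at construction time, for the fully scaled maximum length
(max_position_embeddings * scaling_factor), and builds the whole cos/sin cache up front. A small
worked example of that base adjustment with illustrative numbers (4096 original positions, scaling
factor 2.0, rotary_dim 128, base 10000), not taken from any particular checkpoint:

rotary_dim = 128
base = 10000.0
orig_max_len = 4096
scaling_factor = 2.0

# Same formula as DynamicNTKScalingRotaryEmbedding._compute_cos_sin_cache in the patch.
max_len = orig_max_len * scaling_factor  # 8192 positions after scaling
adjusted_base = base * (
    (scaling_factor * max_len / orig_max_len) -
    (scaling_factor - 1))**(rotary_dim / (rotary_dim - 2))
print(round(adjusted_base))  # roughly 30500: a larger base lowers the rotary frequencies
                             # so the 8192 scaled positions fit within the original wavelengths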