From a9e4574261a20d4ada213d26671da7dc7633580b Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 29 Nov 2023 15:37:31 -0800 Subject: [PATCH] Refactor Attention (#1840) --- vllm/model_executor/layers/attention.py | 540 ++++++------------ .../model_executor/layers/rotary_embedding.py | 4 +- vllm/model_executor/models/aquila.py | 24 +- vllm/model_executor/models/baichuan.py | 32 +- vllm/model_executor/models/bloom.py | 8 +- vllm/model_executor/models/chatglm.py | 19 +- vllm/model_executor/models/falcon.py | 39 +- vllm/model_executor/models/gpt_j.py | 21 +- vllm/model_executor/models/gpt_neox.py | 18 +- vllm/model_executor/models/internlm.py | 18 +- vllm/model_executor/models/llama.py | 25 +- vllm/model_executor/models/mistral.py | 28 +- vllm/model_executor/models/mpt.py | 8 +- vllm/model_executor/models/phi_1_5.py | 18 +- vllm/model_executor/models/qwen.py | 20 +- vllm/model_executor/models/yi.py | 24 +- 16 files changed, 354 insertions(+), 492 deletions(-) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 63271ba5b9327..55b48fc5c7cca 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -1,5 +1,5 @@ """Multi-head attention.""" -from typing import Any, Dict, List, Optional +from typing import List, Optional import torch import torch.nn as nn @@ -10,7 +10,6 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, from vllm._C import ops from vllm._C import cache_ops from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.rotary_embedding import get_rope _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. @@ -18,37 +17,39 @@ _PARTITION_SIZE = 512 class PagedAttention(nn.Module): - """GPT-style multi-head PagedAttention. + """MHA/MQA/GQA layer with PagedAttention. This class takes query, key, and value tensors as input. The input tensors - can either contain prompt tokens or generation tokens, in addition to - paddings. - + can either contain prompt tokens or generation tokens. The class does the following: - 1. Perform multi_query_kv_attention for the prompts. This operation does - not use the KV cache. - 2. Wait for the cache operations (e.g., swap, copy) to finish. The cache + + 1. Wait for the cache operations (e.g., swap, copy) to finish. The cache operations are issued by the cache engine before executing the forward pass of the model, and they are executed asynchronously. - 3. Reshape and store the input key and value tensors in the KV cache. - 4. Perform single_query_cached_kv_attention for the generation tokens. - This operation reads the previous key and value tensors from the KV - cache. - 5. Return the output tensor. + 2. Reshape and store the input key and value tensors in the KV cache. + 3. Perform (multi-head/multi-query/grouped-query) attention using either + xformers or the PagedAttention custom op. + 4. Return the output tensor. 
""" - def __init__(self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: Optional[int] = None, - sliding_window: Optional[int] = None) -> None: + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: super().__init__() self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.sliding_window = sliding_window + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -60,153 +61,6 @@ class PagedAttention(nn.Module): raise ValueError(f"head_size ({self.head_size}) is not supported. " f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.") - def set_attn_bias( - self, - input_metadata: InputMetadata, - dtype: torch.dtype, - ) -> None: - del dtype # Unused. - if input_metadata.attn_bias is not None: - # Already set by a previous layer. - return - prompt_lens = [input_metadata.max_prompt_len - ] * input_metadata.num_prompts - attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens) - if self.sliding_window is not None: - attn_bias = attn_bias.make_local_attention(self.sliding_window) - input_metadata.attn_bias = attn_bias - - def multi_query_kv_attention( - self, - output: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - input_metadata: InputMetadata, - ) -> torch.Tensor: - """Normal attention for the prompt tokens. - - Args: - output: shape = [num_prompt_tokens, num_heads, head_size] - query: shape = [num_prompt_tokens, num_heads, head_size] - key: shape = [num_prompt_tokens, num_kv_heads, head_size] - value: shape = [num_prompt_tokens, num_kv_heads, head_size] - input_metadata: metadata for paged attention. - """ - if self.num_kv_heads != self.num_heads: - # Project the key and value tensors to the desired number of heads. - query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - - # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize. - out = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=input_metadata.attn_bias, - p=0.0, - scale=self.scale, - ) - # TODO(woosuk): Unnecessary copy. Optimize. - output.copy_(out.view_as(output)) - return output - - def get_alibi_slopes(self) -> Optional[torch.Tensor]: - """Returns the slopes for the alibi attention bias. - - Returns: - slopes: shape = [num_heads] - """ - return None - - def single_query_cached_kv_attention( - self, - output: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - input_metadata: InputMetadata, - alibi_slopes: Optional[torch.Tensor], - ) -> None: - """PagedAttention for the generation tokens. 
- - Args: - output: shape = [num_generation_tokens, num_heads, head_size] - query: shape = [num_generation_tokens, num_heads, head_size] - key_cache: shape = [num_blocks, num_kv_heads, head_size/x, - block_size, x] - value_cache: shape = [num_blocks, num_kv_heads, head_size, - block_size] - input_metadata: metadata for paged attention. - alibi_slopes: shape = [num_heads] - """ - block_size = value_cache.shape[3] - num_seqs, num_heads, head_size = query.shape - max_num_partitions = ( - (input_metadata.max_context_len + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) - # NOTE(woosuk): We use a simple heuristic to decide whether to use - # PagedAttention V1 or V2. If the number of partitions is 1, we use - # V1 to avoid the overhead of reduction. Also, if the number of - # sequences or heads is large, we use V1 since there is enough work - # to parallelize. - # TODO(woosuk): Tune this heuristic. - # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = input_metadata.max_context_len <= 8192 and ( - max_num_partitions == 1 or num_seqs * num_heads > 512) - if use_v1: - # Run PagedAttention V1. - ops.paged_attention_v1( - output, - query, - key_cache, - value_cache, - self.head_mapping, - self.scale, - input_metadata.block_tables, - input_metadata.context_lens, - block_size, - input_metadata.max_context_len, - alibi_slopes, - ) - else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=output.dtype, - device=output.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=output.device, - ) - max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - self.head_mapping, - self.scale, - input_metadata.block_tables, - input_metadata.context_lens, - block_size, - input_metadata.max_context_len, - alibi_slopes, - ) - def forward( self, query: torch.Tensor, @@ -219,9 +73,6 @@ class PagedAttention(nn.Module): ) -> torch.Tensor: """PagedAttention forward pass. - NOTE: The query, key, and value tensors must be sliced from a qkv - tensor of shape [batch_size, seq_len, 3 * num_heads * head_size]. - Args: query: shape = [batch_size, seq_len, num_heads * head_size] key: shape = [batch_size, seq_len, num_kv_heads * head_size] @@ -230,46 +81,28 @@ class PagedAttention(nn.Module): block_size, x] value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size] - input_metadata: metadata for paged attention. + input_metadata: metadata for the inputs. cache_event: event to wait for the cache operations to finish. - Returns: shape = [batch_size, seq_len, num_heads * head_size] """ - batch_size, seq_len, _ = query.shape + batch_size, seq_len, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) + slot_mapping = input_metadata.slot_mapping.flatten() - # Pre-allocate the output tensor. - output = torch.empty_like(query) - - # Compute the attention op for prompts. - num_prompt_tokens = input_metadata.num_prompt_tokens - if num_prompt_tokens > 0: - # Prompt run. 
- assert input_metadata.num_generation_tokens == 0 - self.set_attn_bias(input_metadata, dtype=query.dtype) - self.multi_query_kv_attention( - output, - query, - key, - value, - input_metadata, - ) - - # Wait until the cache op is done. if cache_event is not None: cache_event.wait() # Reshape the keys and values and store them in the cache. - # When key_cache and value_cache are not provided, the new key - # and value vectors will not be cached. + # If key_cache and value_cache are not provided, the new key and value + # vectors will not be cached. This happens during the initial memory + # profiling run. if key_cache is not None and value_cache is not None: key_to_cache = key value_to_cache = value - slot_mapping = input_metadata.slot_mapping.view(-1) if input_metadata.to_cache is not None: key_to_cache = key_to_cache[input_metadata.to_cache] value_to_cache = value_to_cache[input_metadata.to_cache] @@ -283,178 +116,175 @@ class PagedAttention(nn.Module): slot_mapping, ) - if input_metadata.num_generation_tokens > 0: + is_prompt = len(input_metadata.prompt_lens) > 0 + if is_prompt: + # Prompt run. + if self.num_kv_heads != self.num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # TODO(woosuk): Use MQA/GQA kernels for higher performance. + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + key.shape[-1]) + value = value[:, :, None, :].expand(value.shape[0], + self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + + # Set attention bias if not provided. This typically happens at the + # very attention layer of every iteration. + # FIXME(woosuk): This is a hack. + if input_metadata.attn_bias is None: + if self.alibi_slopes is None: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + [seq_len] * batch_size) + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + input_metadata.attn_bias = attn_bias + else: + input_metadata.attn_bias = _make_alibi_bias( + self.alibi_slopes, batch_size, seq_len, query.dtype) + + # TODO(woosuk): Too many view operations. Let's try to reduce them + # in the future for code readability. + if self.alibi_slopes is None: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + else: + query = query.unflatten(0, (batch_size, seq_len)) + key = key.unflatten(0, (batch_size, seq_len)) + value = value.unflatten(0, (batch_size, seq_len)) + + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + ) + output = out.view_as(query) + else: # Decoding run. - assert input_metadata.num_prompt_tokens == 0 - assert key_cache is not None and value_cache is not None, ( - "key_cache and value_cache must be provided when " - "generating tokens.") - # Compute the attention op for generation tokens. - self.single_query_cached_kv_attention(output, query, key_cache, - value_cache, input_metadata, - self.get_alibi_slopes()) + output = _paged_attention( + query, + key_cache, + value_cache, + input_metadata, + self.head_mapping, + self.scale, + self.alibi_slopes, + ) # Reshape the output tensor. - # NOTE(woosuk): The output tensor may include paddings. 
- return output.view(batch_size, seq_len, - self.num_heads * self.head_size) + return output.view(batch_size, seq_len, hidden_size) -class PagedAttentionWithRoPE(PagedAttention): - """PagedAttention with rotary positional embedding.""" +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + batch_size: int, + seq_len: int, + dtype: torch.dtype, +) -> LowerTriangularMaskWithTensorBias: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(prompt_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + bias = bias.to(alibi_slopes.device) - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - rotary_dim: int, - max_position: int = 8192, - base: int = 10000, - num_kv_heads: Optional[int] = None, - is_neox_style: bool = True, - rope_scaling: Optional[Dict[str, Any]] = None, - sliding_window: Optional[int] = None, - ) -> None: - super().__init__(num_heads, - head_size, - scale, - num_kv_heads, - sliding_window=sliding_window) - self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base, - is_neox_style, rope_scaling) + # When using custom attention bias, xformers requires the bias to + # be sliced from a tensor whose length is a multiple of 8. + padded_len = (seq_len + 7) // 8 * 8 + bias = torch.empty( + batch_size, + alibi_slopes.shape[0], + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + attn_bias = LowerTriangularMaskWithTensorBias(bias) + return attn_bias - def forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], - ) -> torch.Tensor: - """ PagedAttention forward pass with rotary embedding. - Args: - positions: shape = [batch_size, seq_len] - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] - key_cache: shape = [num_blocks, num_kv_heads, head_size/x, - block_size, x] - value_cache: shape = [num_blocks, num_kv_heads, head_size, - block_size] - input_metadata: metadata for paged attention. - cache_event: event to wait for the cache operations to finish. +def _paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + input_metadata: InputMetadata, + head_mapping: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], +) -> torch.Tensor: + output = torch.empty_like(query) - Returns: - shape = [batch_size, seq_len, num_heads * head_size] - """ - - # Apply rotary embedding to the query and key before passing them - # to the attention op. - query, key = self.rotary_emb(positions, query, key) - return super().forward( + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ( + (input_metadata.max_context_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. 
+ # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory shortage. + use_v1 = input_metadata.max_context_len <= 8192 and ( + max_num_partitions == 1 or num_seqs * num_heads > 512) + if use_v1: + # Run PagedAttention V1. + ops.paged_attention_v1( + output, query, - key, - value, key_cache, value_cache, - input_metadata, - cache_event, + head_mapping, + scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + alibi_slopes, ) - - -class PagedAttentionWithALiBi(PagedAttention): - """PagedAttention with ALiBi attention bias.""" - - def __init__(self, - num_heads: int, - head_size: int, - scale: float, - slopes: List[float], - num_kv_heads: Optional[int] = None) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads) - assert len(slopes) == num_heads - - slopes = torch.tensor(slopes, dtype=torch.float32) - self.register_buffer("alibi_slopes", slopes, persistent=False) - - def set_attn_bias(self, input_metadata: InputMetadata, - dtype: torch.dtype) -> None: - if input_metadata.attn_bias is not None: - # Already set by a previous layer. - return - # Generates ALiBi mask based on the max prompt length. - max_prompt_len = input_metadata.max_prompt_len - bias = torch.arange(max_prompt_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - bias = bias[None, :] - bias[:, None] - bias = bias.to(self.alibi_slopes.device) - - # When using custom attention bias, xformers requires the bias to - # be sliced from a tensor whose length is a multiple of 8. - padded_len = (max_prompt_len + 7) // 8 * 8 - bias = torch.empty( - input_metadata.num_prompts, - self.num_heads, - max_prompt_len, - padded_len, - device=self.alibi_slopes.device, - dtype=dtype, - )[:, :, :, :max_prompt_len].copy_(bias) - bias.mul_(self.alibi_slopes[:, None, None]) - attn_bias = LowerTriangularMaskWithTensorBias(bias) - input_metadata.attn_bias = attn_bias - - def multi_query_kv_attention( - self, - output: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - input_metadata: InputMetadata, - ) -> torch.Tensor: - """Attention with ALiBi bias for the prompt tokens. - - Args: - output: shape = [num_prompt_tokens, num_heads, head_size] - query: shape = [num_prompt_tokens, num_heads, head_size] - key: shape = [num_prompt_tokens, num_kv_heads, head_size] - value: shape = [num_prompt_tokens, num_kv_heads, head_size] - input_metadata: metadata for paged attention. - """ - if self.num_kv_heads != self.num_heads: - # Project the key and value tensors to the desired number of heads. 
- query = query.view(query.shape[0], self.num_kv_heads, - self.num_queries_per_kv, query.shape[-1]) - key = key[:, :, - None, :].expand(key.shape[0], self.num_kv_heads, - self.num_queries_per_kv, key.shape[-1]) - value = value[:, :, - None, :].expand(value.shape[0], self.num_kv_heads, - self.num_queries_per_kv, - value.shape[-1]) - batch_size = input_metadata.num_prompts - seq_len = input_metadata.max_prompt_len - - out = xops.memory_efficient_attention_forward( - query.view(batch_size, seq_len, self.num_heads, self.head_size), - key.view(batch_size, seq_len, self.num_heads, self.head_size), - value.view(batch_size, seq_len, self.num_heads, self.head_size), - attn_bias=input_metadata.attn_bias, - p=0.0, - scale=self.scale, + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, ) - # TODO(woosuk): Unnecessary copy. Optimize. - output.copy_(out.view_as(output)) - return output - - def get_alibi_slopes(self) -> Optional[torch.Tensor]: - return self.alibi_slopes + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + head_mapping, + scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + alibi_slopes, + ) + return output diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 162bb0b533e4f..0bde4cefbb99c 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -277,8 +277,8 @@ def get_rope( rotary_dim: int, max_position: int, base: int, - is_neox_style: bool, - rope_scaling: Optional[Dict[str, Any]], + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, ) -> RotaryEmbedding: if rope_scaling is None: rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py index 889239cdb4e0e..ba2af445b1364 100644 --- a/vllm/model_executor/models/aquila.py +++ b/vllm/model_executor/models/aquila.py @@ -28,11 +28,12 @@ from torch import nn from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -138,15 +139,17 @@ class AquilaAttention(nn.Module): bias=False, linear_method=linear_method, ) - self.attn = PagedAttentionWithRoPE( - self.num_heads, + self.rotary_emb = get_rope( self.head_dim, - self.scaling, - base=self.rope_theta, - max_position=self.max_position_embeddings, rotary_dim=self.head_dim, - num_kv_heads=self.num_kv_heads, - rope_scaling=rope_scaling) + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = 
PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) def forward( self, @@ -158,9 +161,10 @@ class AquilaAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(positions, q, k, v, k_cache, v_cache, - input_metadata, cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, + cache_event) output, _ = self.o_proj(attn_output) return output diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 61cc2192b01bb..e1de6bbefbbc6 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -26,13 +26,13 @@ from torch import nn from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE, - PagedAttentionWithALiBi) +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -150,17 +150,20 @@ class BaiChuanAttention(nn.Module): alibi_slopes = alibi_slopes[head_start:head_end].tolist() scaling = self.head_dim**-0.5 - self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim, - scaling, alibi_slopes) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes) else: - self.scaling = self.head_dim**-0.5 - self.attn = PagedAttentionWithRoPE( - self.num_heads, + self.rotary_emb = get_rope( self.head_dim, - self.scaling, rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, base=self.rope_theta, - max_position=self.max_position_embeddings) + ) + self.scaling = self.head_dim**-0.5 + self.attn = PagedAttention(self.num_heads, self.head_dim, + self.scaling) def forward( self, @@ -172,14 +175,11 @@ class BaiChuanAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.W_pack(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) + if self.postion_embedding != "ALIBI": + q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache - if self.postion_embedding == "ALIBI": - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, - cache_event) - else: - attn_output = self.attn(positions, q, k, v, k_cache, v_cache, - input_metadata, cache_event) - + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, + cache_event) output, _ = self.o_proj(attn_output) return output diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 99ccd7442f31b..1703d1cdb3670 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -25,7 +25,7 @@ from transformers import BloomConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import PagedAttentionWithALiBi +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import 
(ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -106,8 +106,10 @@ class BloomAttention(nn.Module): alibi_slopes = alibi_slopes[head_start:head_end].tolist() scaling = self.head_dim**-0.5 - self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim, - scaling, alibi_slopes) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes) def forward( self, diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 5d243168a41f3..5c08a1a823685 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -10,12 +10,13 @@ from torch.nn import LayerNorm from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -78,16 +79,19 @@ class GLMAttention(nn.Module): # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 rope_ratio = getattr(config, "rope_ratio", 1.0) max_positions = getattr(config, "seq_length", 8192) - self.attn = PagedAttentionWithRoPE( - self.num_heads, + self.rotary_emb = get_rope( self.head_dim, - self.scaling, rotary_dim=self.head_dim // 2, - num_kv_heads=self.num_kv_heads, max_position=max_positions, base=10000 * rope_ratio, is_neox_style=False, ) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + ) def forward( self, @@ -99,10 +103,9 @@ class GLMAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(position_ids, q, k) key_cache, value_cache = kv_cache - context_layer = self.attn( - position_ids, q, k, v, @@ -111,9 +114,7 @@ class GLMAttention(nn.Module): input_metadata, cache_event, ) - attn_output, _ = self.dense(context_layer) - return attn_output diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index ceb7c651823e0..b7af514661a68 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -28,13 +28,12 @@ from transformers import FalconConfig as HF_FalconConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import (PagedAttention, - PagedAttentionWithALiBi, - PagedAttentionWithRoPE) +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -144,14 +143,16 @@ class FalconAttention(nn.Module): rope_theta = getattr(config, "rope_theta", 10000) 
max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - self.attn = PagedAttentionWithRoPE( - self.num_heads, + self.rotary_emb = get_rope( self.head_dim, - self.inv_norm_factor, - base=rope_theta, - max_position=max_position_embeddings, rotary_dim=self.head_dim, - num_kv_heads=self.num_kv_heads) + max_position=max_position_embeddings, + base=rope_theta, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads) elif self.use_alibi: tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads @@ -159,11 +160,11 @@ class FalconAttention(nn.Module): alibi_slopes = (_get_alibi_slopes(self.total_num_heads) * self.inv_norm_factor) alibi_slopes = alibi_slopes[head_start:head_end].tolist() - self.attn = PagedAttentionWithALiBi(self.num_heads, - self.head_dim, - self.inv_norm_factor, - alibi_slopes, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + alibi_slopes=alibi_slopes) else: self.attn = PagedAttention(self.num_heads, self.head_dim, @@ -182,13 +183,11 @@ class FalconAttention(nn.Module): if bias is not None: qkv += bias q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - k_cache, v_cache = kv_cache if self.use_rotary: - attn_output = self.attn(positions, q, k, v, k_cache, v_cache, - input_metadata, cache_event) - else: - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, - cache_event) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, + cache_event) attn_output, bias = self.dense(attn_output) return attn_output, bias diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 1f0f7d4206c88..7db6edd110f27 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -24,11 +24,12 @@ from transformers import GPTJConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -77,15 +78,14 @@ class GPTJAttention(nn.Module): rope_theta = getattr(config, "rope_theta", 10000) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - self.attn = PagedAttentionWithRoPE( - self.num_heads, + self.rotary_emb = get_rope( self.head_size, - scaling, - config.rotary_dim, - base=rope_theta, + rotary_dim=config.rotary_dim, max_position=max_position_embeddings, - is_neox_style=False) - self.warmup = False + base=rope_theta, + is_neox_style=False, + ) + self.attn = PagedAttention(self.num_heads, self.head_size, scaling) def forward( self, @@ -97,9 +97,10 @@ class GPTJAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(position_ids, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache, - input_metadata, 
cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, + cache_event) attn_output, _ = self.out_proj(attn_output) return attn_output diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index b289ddc51da85..1d21d06e21a62 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -24,11 +24,12 @@ from transformers import GPTNeoXConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -77,13 +78,13 @@ class GPTNeoXAttention(nn.Module): rope_theta = getattr(config, "rope_theta", 10000) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) - self.attn = PagedAttentionWithRoPE( - self.num_heads, + self.rotary_emb = get_rope( self.head_size, - scaling, - rotary_dim, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, base=rope_theta, - max_position=max_position_embeddings) + ) + self.attn = PagedAttention(self.num_heads, self.head_size, scaling) def forward( self, @@ -95,9 +96,10 @@ class GPTNeoXAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(position_ids, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache, - input_metadata, cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, + cache_event) output, _ = self.dense(attn_output) return output diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py index 13b2e70deeb86..8b20462c18c15 100644 --- a/vllm/model_executor/models/internlm.py +++ b/vllm/model_executor/models/internlm.py @@ -7,12 +7,13 @@ from transformers import LlamaConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -92,13 +93,13 @@ class InternLMAttention(nn.Module): bias=bias, linear_method=linear_method, ) - self.attn = PagedAttentionWithRoPE( - self.num_heads, + self.rotary_emb = get_rope( self.head_dim, - self.scaling, - base=self.rope_theta, + rotary_dim=self.head_dim, max_position=self.max_position_embeddings, - rotary_dim=self.head_dim) + base=self.rope_theta, + ) + self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling) def forward( self, @@ -110,9 +111,10 @@ class InternLMAttention(nn.Module): ) -> torch.Tensor: qkv, _ = 
self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(positions, q, k, v, k_cache, v_cache, - input_metadata, cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, + cache_event) output, _ = self.o_proj(attn_output) return output diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8e7344da4888e..cd39d9059eeca 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,12 +29,13 @@ from transformers import LlamaConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -126,15 +127,18 @@ class LlamaAttention(nn.Module): bias=False, linear_method=linear_method, ) - self.attn = PagedAttentionWithRoPE( - self.num_heads, + + self.rotary_emb = get_rope( self.head_dim, - self.scaling, - base=self.rope_theta, - max_position=self.max_position_embeddings, rotary_dim=self.head_dim, - num_kv_heads=self.num_kv_heads, - rope_scaling=rope_scaling) + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) def forward( self, @@ -146,9 +150,10 @@ class LlamaAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(positions, q, k, v, k_cache, v_cache, - input_metadata, cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, + cache_event) output, _ = self.o_proj(attn_output) return output diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index d18572610741c..8470020006f3c 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -29,12 +29,13 @@ from transformers import MistralConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -124,14 +125,18 @@ class MistralAttention(nn.Module): bias=False, linear_method=linear_method, ) - self.attn = PagedAttentionWithRoPE(self.num_heads, - self.head_dim, - self.scaling, - 
base=self.rope_theta, - max_position=max_position, - rotary_dim=self.head_dim, - num_kv_heads=self.num_kv_heads, - sliding_window=self.sliding_window) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=self.rope_theta, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window) def forward( self, @@ -143,9 +148,10 @@ class MistralAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(positions, q, k, v, k_cache, v_cache, - input_metadata, cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, + cache_event) output, _ = self.o_proj(attn_output) return output diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 47130649d3c6c..caa169b57d481 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -8,7 +8,7 @@ import torch.nn as nn from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import PagedAttentionWithALiBi +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, @@ -87,8 +87,10 @@ class MPTAttention(nn.Module): self.head_dim = self.d_model // self.total_num_heads scaling = self.head_dim**-0.5 - self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim, - scaling, alibi_slopes) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes) def forward( self, diff --git a/vllm/model_executor/models/phi_1_5.py b/vllm/model_executor/models/phi_1_5.py index 7ef614601da39..bd4afbf2ca973 100644 --- a/vllm/model_executor/models/phi_1_5.py +++ b/vllm/model_executor/models/phi_1_5.py @@ -43,11 +43,12 @@ from transformers import PretrainedConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -119,13 +120,13 @@ class PhiAttention(nn.Module): # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518 rope_theta = 10000 max_position_embeddings = getattr(config, "n_positions", 2048) - self.attn = PagedAttentionWithRoPE( - self.num_heads, + self.rotary_emb = get_rope( self.head_size, - scaling, - rotary_dim, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, base=rope_theta, - max_position=max_position_embeddings) + ) + self.attn = PagedAttention(self.num_heads, self.head_size, scaling) def forward( self, @@ -137,9 +138,10 @@ class PhiAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.Wqkv(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = 
self.rotary_emb(position_ids, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache, - input_metadata, cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, + cache_event) output, _ = self.out_proj(attn_output) return output diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index d581838f6ce8f..e6c089b3f8289 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -11,12 +11,13 @@ from torch import nn from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -95,14 +96,15 @@ class QWenAttention(nn.Module): linear_method=linear_method, ) self.scaling = self.head_dim**-0.5 - self.attn = PagedAttentionWithRoPE( - self.num_heads, + + self.rotary_emb = get_rope( self.head_dim, - self.scaling, rotary_dim=self.head_dim, - base=rope_theta, max_position=max_position_embeddings, - rope_scaling=rope_scaling) + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling) def forward( self, @@ -114,10 +116,10 @@ class QWenAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) - + q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(positions, q, k, v, k_cache, v_cache, - input_metadata, cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, + cache_event) output, _ = self.c_proj(attn_output) return output diff --git a/vllm/model_executor/models/yi.py b/vllm/model_executor/models/yi.py index c457132855cdc..af83241412727 100644 --- a/vllm/model_executor/models/yi.py +++ b/vllm/model_executor/models/yi.py @@ -29,12 +29,13 @@ from vllm.transformers_utils.configs.yi import YiConfig from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -126,15 +127,17 @@ class YiAttention(nn.Module): bias=False, linear_method=linear_method, ) - self.attn = PagedAttentionWithRoPE( - self.num_heads, + self.rotary_emb = get_rope( self.head_dim, - self.scaling, - base=self.rope_theta, - max_position=self.max_position_embeddings, rotary_dim=self.head_dim, - num_kv_heads=self.num_kv_heads, - rope_scaling=rope_scaling) + max_position=max_position_embeddings, + base=self.rope_theta, + 
rope_scaling=rope_scaling, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) def forward( self, @@ -146,9 +149,10 @@ class YiAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(positions, q, k, v, k_cache, v_cache, - input_metadata, cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, + cache_event) output, _ = self.o_proj(attn_output) return output
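
Usage note (illustrative, not part of the patch): after this refactor, PagedAttentionWithRoPE and PagedAttentionWithALiBi are removed. Models now build a rotary embedding separately via get_rope() and pass alibi_slopes / sliding_window straight to the unified PagedAttention, whose forward() no longer takes positions. The sketch below restates that call pattern only; ExampleAttention and the head counts / max_position values are placeholders, while the imports, constructor arguments, and forward signature follow the model files changed in this patch.

from typing import Optional, Tuple

import torch
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.rotary_embedding import get_rope


class ExampleAttention(nn.Module):
    """Minimal post-refactor attention wiring (hypothetical sizes)."""

    def __init__(self) -> None:
        super().__init__()
        num_heads, num_kv_heads, head_dim = 32, 8, 128  # placeholder config
        # Rotary embedding is now a standalone module created via get_rope().
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=4096,
            base=10000,
        )
        # A single PagedAttention class covers MHA/MQA/GQA; ALiBi and
        # sliding-window behavior are selected through keyword arguments
        # (alibi_slopes=..., sliding_window=...) instead of subclasses.
        self.attn = PagedAttention(num_heads,
                                   head_dim,
                                   scale=head_dim**-0.5,
                                   num_kv_heads=num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # The model applies rotary embeddings before attention, so
        # PagedAttention.forward() no longer receives `positions`.
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        return self.attn(q, k, v, k_cache, v_cache, input_metadata,
                         cache_event)

In a real model this wiring sits next to the QKV and output projections, exactly as in LlamaAttention and the other attention modules updated above.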