Refactor Attention (#1840)
This commit is contained in:
parent 0229c386c5
commit a9e4574261
@@ -1,5 +1,5 @@
"""Multi-head attention."""
from typing import Any, Dict, List, Optional
from typing import List, Optional

import torch
import torch.nn as nn
@@ -10,7 +10,6 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.rotary_embedding import get_rope

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -18,37 +17,39 @@ _PARTITION_SIZE = 512


class PagedAttention(nn.Module):
    """GPT-style multi-head PagedAttention.
    """MHA/MQA/GQA layer with PagedAttention.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens, in addition to
    paddings.

    can either contain prompt tokens or generation tokens.
    The class does the following:
    1. Perform multi_query_kv_attention for the prompts. This operation does
        not use the KV cache.
    2. Wait for the cache operations (e.g., swap, copy) to finish. The cache

    1. Wait for the cache operations (e.g., swap, copy) to finish. The cache
        operations are issued by the cache engine before executing the forward
        pass of the model, and they are executed asynchronously.
    3. Reshape and store the input key and value tensors in the KV cache.
    4. Perform single_query_cached_kv_attention for the generation tokens.
        This operation reads the previous key and value tensors from the KV
        cache.
    5. Return the output tensor.
    2. Reshape and store the input key and value tensors in the KV cache.
    3. Perform (multi-head/multi-query/grouped-query) attention using either
        xformers or the PagedAttention custom op.
    4. Return the output tensor.
    """

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 num_kv_heads: Optional[int] = None,
                 sliding_window: Optional[int] = None) -> None:
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -60,153 +61,6 @@ class PagedAttention(nn.Module):
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
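Editor's note (not part of the commit): after this refactor a single `PagedAttention` class covers the configurations that previously needed `PagedAttention`, `PagedAttentionWithRoPE`, and `PagedAttentionWithALiBi`. A minimal construction sketch based only on the constructor signature above; the head counts and slope values are made up for illustration.

```python
import torch

from vllm.model_executor.layers.attention import PagedAttention

# Plain multi-head attention (MHA): one KV head per query head.
mha = PagedAttention(num_heads=32, head_size=128, scale=128**-0.5)

# Grouped-query attention (GQA): fewer KV heads than query heads.
gqa = PagedAttention(num_heads=32, head_size=128, scale=128**-0.5,
                     num_kv_heads=8)

# ALiBi attention: pass per-head slopes; no rotary embedding is applied here.
alibi = PagedAttention(num_heads=32, head_size=128, scale=128**-0.5,
                       alibi_slopes=[0.5 / 2**i for i in range(32)])

# Sliding-window (local) attention, e.g. Mistral-style.
swa = PagedAttention(num_heads=32, head_size=128, scale=128**-0.5,
                     num_kv_heads=8, sliding_window=4096)
```

Rotary embeddings are no longer part of this layer at all; the model code applies `get_rope(...)` to the query and key before calling the layer, as the model-file hunks further down show.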

    def set_attn_bias(
        self,
        input_metadata: InputMetadata,
        dtype: torch.dtype,
    ) -> None:
        del dtype  # Unused.
        if input_metadata.attn_bias is not None:
            # Already set by a previous layer.
            return
        prompt_lens = [input_metadata.max_prompt_len
                       ] * input_metadata.num_prompts
        attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
        if self.sliding_window is not None:
            attn_bias = attn_bias.make_local_attention(self.sliding_window)
        input_metadata.attn_bias = attn_bias

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Normal attention for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """
        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])

        # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=input_metadata.attn_bias,
            p=0.0,
            scale=self.scale,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.view_as(output))
        return output

    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
        """Returns the slopes for the alibi attention bias.

        Returns:
            slopes: shape = [num_heads]
        """
        return None

    def single_query_cached_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        alibi_slopes: Optional[torch.Tensor],
    ) -> None:
        """PagedAttention for the generation tokens.

        Args:
            output: shape = [num_generation_tokens, num_heads, head_size]
            query: shape = [num_generation_tokens, num_heads, head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            alibi_slopes: shape = [num_heads]
        """
        block_size = value_cache.shape[3]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = (
            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
            _PARTITION_SIZE)
        # NOTE(woosuk): We use a simple heuristic to decide whether to use
        # PagedAttention V1 or V2. If the number of partitions is 1, we use
        # V1 to avoid the overhead of reduction. Also, if the number of
        # sequences or heads is large, we use V1 since there is enough work
        # to parallelize.
        # TODO(woosuk): Tune this heuristic.
        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
        use_v1 = input_metadata.max_context_len <= 8192 and (
            max_num_partitions == 1 or num_seqs * num_heads > 512)
        if use_v1:
            # Run PagedAttention V1.
            ops.paged_attention_v1(
                output,
                query,
                key_cache,
                value_cache,
                self.head_mapping,
                self.scale,
                input_metadata.block_tables,
                input_metadata.context_lens,
                block_size,
                input_metadata.max_context_len,
                alibi_slopes,
            )
        else:
            # Run PagedAttention V2.
            assert _PARTITION_SIZE % block_size == 0
            tmp_output = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions, head_size),
                dtype=output.dtype,
                device=output.device,
            )
            exp_sums = torch.empty(
                size=(num_seqs, num_heads, max_num_partitions),
                dtype=torch.float32,
                device=output.device,
            )
            max_logits = torch.empty_like(exp_sums)
            ops.paged_attention_v2(
                output,
                exp_sums,
                max_logits,
                tmp_output,
                query,
                key_cache,
                value_cache,
                self.head_mapping,
                self.scale,
                input_metadata.block_tables,
                input_metadata.context_lens,
                block_size,
                input_metadata.max_context_len,
                alibi_slopes,
            )

    def forward(
        self,
        query: torch.Tensor,
@@ -219,9 +73,6 @@ class PagedAttention(nn.Module):
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        NOTE: The query, key, and value tensors must be sliced from a qkv
            tensor of shape [batch_size, seq_len, 3 * num_heads * head_size].

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
@@ -230,46 +81,28 @@ class PagedAttention(nn.Module):
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            input_metadata: metadata for the inputs.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, _ = query.shape
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        slot_mapping = input_metadata.slot_mapping.flatten()

        # Pre-allocate the output tensor.
        output = torch.empty_like(query)

        # Compute the attention op for prompts.
        num_prompt_tokens = input_metadata.num_prompt_tokens
        if num_prompt_tokens > 0:
            # Prompt run.
            assert input_metadata.num_generation_tokens == 0
            self.set_attn_bias(input_metadata, dtype=query.dtype)
            self.multi_query_kv_attention(
                output,
                query,
                key,
                value,
                input_metadata,
            )

        # Wait until the cache op is done.
        if cache_event is not None:
            cache_event.wait()

        # Reshape the keys and values and store them in the cache.
        # When key_cache and value_cache are not provided, the new key
        # and value vectors will not be cached.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            key_to_cache = key
            value_to_cache = value
            slot_mapping = input_metadata.slot_mapping.view(-1)
            if input_metadata.to_cache is not None:
                key_to_cache = key_to_cache[input_metadata.to_cache]
                value_to_cache = value_to_cache[input_metadata.to_cache]
@@ -283,178 +116,175 @@ class PagedAttention(nn.Module):
                slot_mapping,
            )

        if input_metadata.num_generation_tokens > 0:
        is_prompt = len(input_metadata.prompt_lens) > 0
        if is_prompt:
            # Prompt run.
            if self.num_kv_heads != self.num_heads:
                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
                # project the key and value tensors to the desired number of
                # heads.
                # TODO(woosuk): Use MQA/GQA kernels for higher performance.
                query = query.view(query.shape[0], self.num_kv_heads,
                                   self.num_queries_per_kv, query.shape[-1])
                key = key[:, :,
                          None, :].expand(key.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          key.shape[-1])
                value = value[:, :, None, :].expand(value.shape[0],
                                                    self.num_kv_heads,
                                                    self.num_queries_per_kv,
                                                    value.shape[-1])

            # Set attention bias if not provided. This typically happens at the
            # very attention layer of every iteration.
            # FIXME(woosuk): This is a hack.
            if input_metadata.attn_bias is None:
                if self.alibi_slopes is None:
                    attn_bias = BlockDiagonalCausalMask.from_seqlens(
                        [seq_len] * batch_size)
                    if self.sliding_window is not None:
                        attn_bias = attn_bias.make_local_attention(
                            self.sliding_window)
                    input_metadata.attn_bias = attn_bias
                else:
                    input_metadata.attn_bias = _make_alibi_bias(
                        self.alibi_slopes, batch_size, seq_len, query.dtype)

            # TODO(woosuk): Too many view operations. Let's try to reduce them
            # in the future for code readability.
            if self.alibi_slopes is None:
                query = query.unsqueeze(0)
                key = key.unsqueeze(0)
                value = value.unsqueeze(0)
            else:
                query = query.unflatten(0, (batch_size, seq_len))
                key = key.unflatten(0, (batch_size, seq_len))
                value = value.unflatten(0, (batch_size, seq_len))

            out = xops.memory_efficient_attention_forward(
                query,
                key,
                value,
                attn_bias=input_metadata.attn_bias,
                p=0.0,
                scale=self.scale,
            )
            output = out.view_as(query)
        else:
            # Decoding run.
            assert input_metadata.num_prompt_tokens == 0
            assert key_cache is not None and value_cache is not None, (
                "key_cache and value_cache must be provided when "
                "generating tokens.")
            # Compute the attention op for generation tokens.
            self.single_query_cached_kv_attention(output, query, key_cache,
                                                  value_cache, input_metadata,
                                                  self.get_alibi_slopes())
            output = _paged_attention(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.head_mapping,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        # NOTE(woosuk): The output tensor may include paddings.
        return output.view(batch_size, seq_len,
                           self.num_heads * self.head_size)
        return output.view(batch_size, seq_len, hidden_size)
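Editor's note (not part of the commit): the new prompt path above broadcasts each KV head across its query group, because xformers' kernel expects an equal number of query and KV heads. A standalone torch sketch of that reshaping, with toy sizes:

```python
import torch

# Toy sizes; the real values come from the model config.
num_tokens, num_heads, num_kv_heads, head_size = 6, 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)

# Group the query heads by the KV head they share, then broadcast each KV
# head across its query group (a zero-copy expand, not a real repeat).
query = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
key = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                num_queries_per_kv, head_size)
value = value[:, :, None, :].expand(num_tokens, num_kv_heads,
                                    num_queries_per_kv, head_size)

print(query.shape, key.shape, value.shape)
# All three are torch.Size([6, 2, 4, 64]).
```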


class PagedAttentionWithRoPE(PagedAttention):
    """PagedAttention with rotary positional embedding."""
def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
    bias = torch.arange(seq_len, dtype=dtype)
    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(prompt_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
    bias = bias[None, :] - bias[:, None]
    bias = bias.to(alibi_slopes.device)

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        rotary_dim: int,
        max_position: int = 8192,
        base: int = 10000,
        num_kv_heads: Optional[int] = None,
        is_neox_style: bool = True,
        rope_scaling: Optional[Dict[str, Any]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__(num_heads,
                         head_size,
                         scale,
                         num_kv_heads,
                         sliding_window=sliding_window)
        self.rotary_emb = get_rope(head_size, rotary_dim, max_position, base,
                                   is_neox_style, rope_scaling)
    # When using custom attention bias, xformers requires the bias to
    # be sliced from a tensor whose length is a multiple of 8.
    padded_len = (seq_len + 7) // 8 * 8
    bias = torch.empty(
        batch_size,
        alibi_slopes.shape[0],
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )[:, :, :, :seq_len].copy_(bias)
    bias.mul_(alibi_slopes[:, None, None])
    attn_bias = LowerTriangularMaskWithTensorBias(bias)
    return attn_bias
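Editor's note (not part of the commit): `_make_alibi_bias`, the new module-level helper interleaved above with the removed `PagedAttentionWithRoPE` code, builds a per-head relative-position bias and pads its last dimension to a multiple of 8 for xformers. A small standalone sketch of the same computation, omitting the batch dimension:

```python
import torch

# Toy ALiBi bias: 2 heads, prompt length 5, float32 for clarity.
seq_len = 5
alibi_slopes = torch.tensor([0.5, 0.25])

# Relative positions: bias[i, j] = j - i, matching _make_alibi_bias.
pos = torch.arange(seq_len, dtype=torch.float32)
bias = pos[None, :] - pos[:, None]

# xformers requires the bias to be sliced from a tensor whose last
# dimension is a multiple of 8.
padded_len = (seq_len + 7) // 8 * 8
padded = torch.zeros(alibi_slopes.shape[0], seq_len, padded_len)
padded[:, :, :seq_len] = bias
padded.mul_(alibi_slopes[:, None, None])
# LowerTriangularMaskWithTensorBias then applies the causal mask on top of
# this per-head bias inside the attention kernel.
```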

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """ PagedAttention forward pass with rotary embedding.

        Args:
            positions: shape = [batch_size, seq_len]
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.
def _paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    input_metadata: InputMetadata,
    head_mapping: torch.Tensor,
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
) -> torch.Tensor:
    output = torch.empty_like(query)

        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """

        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        query, key = self.rotary_emb(positions, query, key)
        return super().forward(
    block_size = value_cache.shape[3]
    num_seqs, num_heads, head_size = query.shape
    max_num_partitions = (
        (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
        _PARTITION_SIZE)
    # NOTE(woosuk): We use a simple heuristic to decide whether to use
    # PagedAttention V1 or V2. If the number of partitions is 1, we use
    # V1 to avoid the overhead of reduction. Also, if the number of
    # sequences or heads is large, we use V1 since there is enough work
    # to parallelize.
    # TODO(woosuk): Tune this heuristic.
    # For context len > 8192, use V2 kernel to avoid shared memory shortage.
    use_v1 = input_metadata.max_context_len <= 8192 and (
        max_num_partitions == 1 or num_seqs * num_heads > 512)
    if use_v1:
        # Run PagedAttention V1.
        ops.paged_attention_v1(
            output,
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
            head_mapping,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
        )


class PagedAttentionWithALiBi(PagedAttention):
    """PagedAttention with ALiBi attention bias."""

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 slopes: List[float],
                 num_kv_heads: Optional[int] = None) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads)
        assert len(slopes) == num_heads

        slopes = torch.tensor(slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", slopes, persistent=False)

    def set_attn_bias(self, input_metadata: InputMetadata,
                      dtype: torch.dtype) -> None:
        if input_metadata.attn_bias is not None:
            # Already set by a previous layer.
            return
        # Generates ALiBi mask based on the max prompt length.
        max_prompt_len = input_metadata.max_prompt_len
        bias = torch.arange(max_prompt_len, dtype=dtype)
        # NOTE(zhuohan): HF uses
        #     `bias = bias[None, :].repeat(prompt_len, 1)`
        # here. We find that both biases give the same results, but
        # the bias below more accurately follows the original ALiBi
        # paper.
        bias = bias[None, :] - bias[:, None]
        bias = bias.to(self.alibi_slopes.device)

        # When using custom attention bias, xformers requires the bias to
        # be sliced from a tensor whose length is a multiple of 8.
        padded_len = (max_prompt_len + 7) // 8 * 8
        bias = torch.empty(
            input_metadata.num_prompts,
            self.num_heads,
            max_prompt_len,
            padded_len,
            device=self.alibi_slopes.device,
            dtype=dtype,
        )[:, :, :, :max_prompt_len].copy_(bias)
        bias.mul_(self.alibi_slopes[:, None, None])
        attn_bias = LowerTriangularMaskWithTensorBias(bias)
        input_metadata.attn_bias = attn_bias

    def multi_query_kv_attention(
        self,
        output: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Attention with ALiBi bias for the prompt tokens.

        Args:
            output: shape = [num_prompt_tokens, num_heads, head_size]
            query: shape = [num_prompt_tokens, num_heads, head_size]
            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
            input_metadata: metadata for paged attention.
        """
        if self.num_kv_heads != self.num_heads:
            # Project the key and value tensors to the desired number of heads.
            query = query.view(query.shape[0], self.num_kv_heads,
                               self.num_queries_per_kv, query.shape[-1])
            key = key[:, :,
                      None, :].expand(key.shape[0], self.num_kv_heads,
                                      self.num_queries_per_kv, key.shape[-1])
            value = value[:, :,
                          None, :].expand(value.shape[0], self.num_kv_heads,
                                          self.num_queries_per_kv,
                                          value.shape[-1])
        batch_size = input_metadata.num_prompts
        seq_len = input_metadata.max_prompt_len

        out = xops.memory_efficient_attention_forward(
            query.view(batch_size, seq_len, self.num_heads, self.head_size),
            key.view(batch_size, seq_len, self.num_heads, self.head_size),
            value.view(batch_size, seq_len, self.num_heads, self.head_size),
            attn_bias=input_metadata.attn_bias,
            p=0.0,
            scale=self.scale,
    else:
        # Run PagedAttention V2.
        assert _PARTITION_SIZE % block_size == 0
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        # TODO(woosuk): Unnecessary copy. Optimize.
        output.copy_(out.view_as(output))
        return output

    def get_alibi_slopes(self) -> Optional[torch.Tensor]:
        return self.alibi_slopes
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
        )
    return output

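Editor's note (not part of the commit): the new module-level `_paged_attention` keeps the existing V1/V2 kernel dispatch heuristic unchanged. Restated as a standalone function for clarity; the helper name is ours, not vLLM's.

```python
_PARTITION_SIZE = 512


def use_paged_attention_v1(max_context_len: int, num_seqs: int,
                           num_heads: int) -> bool:
    """Mirror of the V1/V2 dispatch heuristic used in _paged_attention."""
    max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) //
                          _PARTITION_SIZE)
    return max_context_len <= 8192 and (max_num_partitions == 1
                                        or num_seqs * num_heads > 512)


# A short decode batch fits in one partition, so V1 skips the reduction step.
print(use_paged_attention_v1(max_context_len=300, num_seqs=4, num_heads=32))   # True
# A single long sequence needs many partitions and few heads, so V2 wins.
print(use_paged_attention_v1(max_context_len=6000, num_seqs=1, num_heads=32))  # False
```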
@@ -277,8 +277,8 @@ def get_rope(
    rotary_dim: int,
    max_position: int,
    base: int,
    is_neox_style: bool,
    rope_scaling: Optional[Dict[str, Any]],
    is_neox_style: bool = True,
    rope_scaling: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding:
    if rope_scaling is None:
        rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
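Editor's note (not part of the commit): with `is_neox_style` and `rope_scaling` now defaulted, the common NeoX-style, unscaled case needs only the first four arguments. Illustrative sizes:

```python
from vllm.model_executor.layers.rotary_embedding import get_rope

# NeoX-style, unscaled RoPE with the new defaults; values are illustrative.
rotary_emb = get_rope(128, rotary_dim=128, max_position=4096, base=10000)
```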
@@ -28,11 +28,12 @@ from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@@ -138,15 +139,17 @@ class AquilaAttention(nn.Module):
            bias=False,
            linear_method=linear_method,
        )
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
        self.rotary_emb = get_rope(
            self.head_dim,
            self.scaling,
            base=self.rope_theta,
            max_position=self.max_position_embeddings,
            rotary_dim=self.head_dim,
            num_kv_heads=self.num_kv_heads,
            rope_scaling=rope_scaling)
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
                                   num_kv_heads=self.num_kv_heads)

    def forward(
        self,
@@ -158,9 +161,10 @@ class AquilaAttention(nn.Module):
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.o_proj(attn_output)
        return output

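Editor's note (not part of the commit): every model file below repeats the same two-step replacement shown in this Aquila hunk: build a rotary embedding with `get_rope` and a position-free `PagedAttention` in `__init__`, then apply the rotary embedding explicitly in `forward`. A condensed, hypothetical module showing just that pattern (sizes and RoPE parameters are illustrative, and the QKV/output projections are omitted):

```python
import torch.nn as nn

from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.rotary_embedding import get_rope


class ToyAttention(nn.Module):
    """Condensed sketch of the per-model wiring after this refactor."""

    def __init__(self, num_heads: int, num_kv_heads: int, head_dim: int):
        super().__init__()
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=8192,
            base=10000,
        )
        self.attn = PagedAttention(num_heads,
                                   head_dim,
                                   head_dim**-0.5,
                                   num_kv_heads=num_kv_heads)

    def forward(self, positions, q, k, v, k_cache, v_cache, input_metadata,
                cache_event):
        # Rotary embedding is now applied by the model code; the attention
        # layer itself no longer takes `positions`.
        q, k = self.rotary_emb(positions, q, k)
        return self.attn(q, k, v, k_cache, v_cache, input_metadata,
                         cache_event)
```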
@@ -26,13 +26,13 @@ from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
                                                   PagedAttentionWithALiBi)
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@@ -150,17 +150,20 @@ class BaiChuanAttention(nn.Module):
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            scaling = self.head_dim**-0.5
            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
                                                scaling, alibi_slopes)
            self.attn = PagedAttention(self.num_heads,
                                       self.head_dim,
                                       scaling,
                                       alibi_slopes=alibi_slopes)
        else:
            self.scaling = self.head_dim**-0.5
            self.attn = PagedAttentionWithRoPE(
                self.num_heads,
            self.rotary_emb = get_rope(
                self.head_dim,
                self.scaling,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
                max_position=self.max_position_embeddings)
            )
            self.scaling = self.head_dim**-0.5
            self.attn = PagedAttention(self.num_heads, self.head_dim,
                                       self.scaling)

    def forward(
        self,
@@ -172,14 +175,11 @@ class BaiChuanAttention(nn.Module):
    ) -> torch.Tensor:
        qkv, _ = self.W_pack(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.postion_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        if self.postion_embedding == "ALIBI":
            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                    cache_event)
        else:
            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                    input_metadata, cache_event)

        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.o_proj(attn_output)
        return output

@@ -25,7 +25,7 @@ from transformers import BloomConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@@ -106,8 +106,10 @@ class BloomAttention(nn.Module):
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
                                            scaling, alibi_slopes)
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scaling,
                                   alibi_slopes=alibi_slopes)

    def forward(
        self,
@@ -10,12 +10,13 @@ from torch.nn import LayerNorm

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@@ -78,16 +79,19 @@ class GLMAttention(nn.Module):
        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
        self.rotary_emb = get_rope(
            self.head_dim,
            self.scaling,
            rotary_dim=self.head_dim // 2,
            num_kv_heads=self.num_kv_heads,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
        )
        self.attn = PagedAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )

    def forward(
        self,
@@ -99,10 +103,9 @@ class GLMAttention(nn.Module):
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        key_cache, value_cache = kv_cache

        context_layer = self.attn(
            position_ids,
            q,
            k,
            v,
@@ -111,9 +114,7 @@ class GLMAttention(nn.Module):
            input_metadata,
            cache_event,
        )

        attn_output, _ = self.dense(context_layer)

        return attn_output

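Editor's note (not part of the commit): the ChatGLM hunk above shows that model-specific rotary variants, here a partial rotary dimension (half of each head), GPT-J-style interleaving, and a scaled base, are now expressed purely through `get_rope` arguments rather than through the attention class. Illustrative values:

```python
from vllm.model_executor.layers.rotary_embedding import get_rope

head_dim, rope_ratio, max_positions = 128, 1.0, 8192  # illustrative values

# Rotate only half of each head, GPT-J (non-NeoX) layout, scaled base.
rotary_emb = get_rope(
    head_dim,
    rotary_dim=head_dim // 2,
    max_position=max_positions,
    base=10000 * rope_ratio,
    is_neox_style=False,
)
```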
@@ -28,13 +28,12 @@ from transformers import FalconConfig as HF_FalconConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import (PagedAttention,
                                                   PagedAttentionWithALiBi,
                                                   PagedAttentionWithRoPE)
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@@ -144,14 +143,16 @@ class FalconAttention(nn.Module):
            rope_theta = getattr(config, "rope_theta", 10000)
            max_position_embeddings = getattr(config,
                                              "max_position_embeddings", 8192)
            self.attn = PagedAttentionWithRoPE(
                self.num_heads,
            self.rotary_emb = get_rope(
                self.head_dim,
                self.inv_norm_factor,
                base=rope_theta,
                max_position=max_position_embeddings,
                rotary_dim=self.head_dim,
                num_kv_heads=self.num_kv_heads)
                max_position=max_position_embeddings,
                base=rope_theta,
            )
            self.attn = PagedAttention(self.num_heads,
                                       self.head_dim,
                                       self.inv_norm_factor,
                                       num_kv_heads=self.num_kv_heads)
        elif self.use_alibi:
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
@@ -159,11 +160,11 @@ class FalconAttention(nn.Module):
            alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                            self.inv_norm_factor)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
            self.attn = PagedAttentionWithALiBi(self.num_heads,
                                                self.head_dim,
                                                self.inv_norm_factor,
                                                alibi_slopes,
                                                num_kv_heads=self.num_kv_heads)
            self.attn = PagedAttention(self.num_heads,
                                       self.head_dim,
                                       self.inv_norm_factor,
                                       num_kv_heads=self.num_kv_heads,
                                       alibi_slopes=alibi_slopes)
        else:
            self.attn = PagedAttention(self.num_heads,
                                       self.head_dim,
@@ -182,13 +183,11 @@ class FalconAttention(nn.Module):
        if bias is not None:
            qkv += bias
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        k_cache, v_cache = kv_cache
        if self.use_rotary:
            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                    input_metadata, cache_event)
        else:
            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                    cache_event)
            q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        attn_output, bias = self.dense(attn_output)
        return attn_output, bias

@@ -24,11 +24,12 @@ from transformers import GPTJConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@@ -77,15 +78,14 @@ class GPTJAttention(nn.Module):
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
        self.rotary_emb = get_rope(
            self.head_size,
            scaling,
            config.rotary_dim,
            base=rope_theta,
            rotary_dim=config.rotary_dim,
            max_position=max_position_embeddings,
            is_neox_style=False)
        self.warmup = False
            base=rope_theta,
            is_neox_style=False,
        )
        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
@@ -97,9 +97,10 @@ class GPTJAttention(nn.Module):
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        attn_output, _ = self.out_proj(attn_output)
        return attn_output

@@ -24,11 +24,12 @@ from transformers import GPTNeoXConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@@ -77,13 +78,13 @@ class GPTNeoXAttention(nn.Module):
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
        self.rotary_emb = get_rope(
            self.head_size,
            scaling,
            rotary_dim,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            max_position=max_position_embeddings)
        )
        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
@@ -95,9 +96,10 @@ class GPTNeoXAttention(nn.Module):
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.dense(attn_output)
        return output

@@ -7,12 +7,13 @@ from transformers import LlamaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@@ -92,13 +93,13 @@ class InternLMAttention(nn.Module):
            bias=bias,
            linear_method=linear_method,
        )
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
        self.rotary_emb = get_rope(
            self.head_dim,
            self.scaling,
            base=self.rope_theta,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rotary_dim=self.head_dim)
            base=self.rope_theta,
        )
        self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)

    def forward(
        self,
@@ -110,9 +111,10 @@ class InternLMAttention(nn.Module):
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.o_proj(attn_output)
        return output

@@ -29,12 +29,13 @@ from transformers import LlamaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@@ -126,15 +127,18 @@ class LlamaAttention(nn.Module):
            bias=False,
            linear_method=linear_method,
        )
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,

        self.rotary_emb = get_rope(
            self.head_dim,
            self.scaling,
            base=self.rope_theta,
            max_position=self.max_position_embeddings,
            rotary_dim=self.head_dim,
            num_kv_heads=self.num_kv_heads,
            rope_scaling=rope_scaling)
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
                                   num_kv_heads=self.num_kv_heads)

    def forward(
        self,
@@ -146,9 +150,10 @@ class LlamaAttention(nn.Module):
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.o_proj(attn_output)
        return output

@@ -29,12 +29,13 @@ from transformers import MistralConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@@ -124,14 +125,18 @@ class MistralAttention(nn.Module):
            bias=False,
            linear_method=linear_method,
        )
        self.attn = PagedAttentionWithRoPE(self.num_heads,
                                           self.head_dim,
                                           self.scaling,
                                           base=self.rope_theta,
                                           max_position=max_position,
                                           rotary_dim=self.head_dim,
                                           num_kv_heads=self.num_kv_heads,
                                           sliding_window=self.sliding_window)

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=self.rope_theta,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
                                   num_kv_heads=self.num_kv_heads,
                                   sliding_window=self.sliding_window)

    def forward(
        self,
@@ -143,9 +148,10 @@ class MistralAttention(nn.Module):
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.o_proj(attn_output)
        return output

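Editor's note (not part of the commit): for Mistral the sliding-window setting simply moves with the attention class; it is another `PagedAttention` constructor argument, which the prompt path turns into a local-attention mask via `make_local_attention`. With illustrative Mistral-7B-like sizes:

```python
from vllm.model_executor.layers.attention import PagedAttention

# Illustrative Mistral-7B-style sizes.
num_heads, num_kv_heads, head_dim, sliding_window = 32, 8, 128, 4096

attn = PagedAttention(num_heads,
                      head_dim,
                      head_dim**-0.5,
                      num_kv_heads=num_kv_heads,
                      sliding_window=sliding_window)
```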
@@ -8,7 +8,7 @@ import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@@ -87,8 +87,10 @@ class MPTAttention(nn.Module):

        self.head_dim = self.d_model // self.total_num_heads
        scaling = self.head_dim**-0.5
        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
                                            scaling, alibi_slopes)
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   scaling,
                                   alibi_slopes=alibi_slopes)

    def forward(
        self,
@@ -43,11 +43,12 @@ from transformers import PretrainedConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@@ -119,13 +120,13 @@ class PhiAttention(nn.Module):
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        rope_theta = 10000
        max_position_embeddings = getattr(config, "n_positions", 2048)
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
        self.rotary_emb = get_rope(
            self.head_size,
            scaling,
            rotary_dim,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            max_position=max_position_embeddings)
        )
        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
@@ -137,9 +138,10 @@ class PhiAttention(nn.Module):
    ) -> torch.Tensor:
        qkv, _ = self.Wqkv(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.out_proj(attn_output)
        return output

@@ -11,12 +11,13 @@ from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
@@ -95,14 +96,15 @@ class QWenAttention(nn.Module):
            linear_method=linear_method,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,

        self.rotary_emb = get_rope(
            self.head_dim,
            self.scaling,
            rotary_dim=self.head_dim,
            base=rope_theta,
            max_position=max_position_embeddings,
            rope_scaling=rope_scaling)
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)

    def forward(
        self,
@@ -114,10 +116,10 @@ class QWenAttention(nn.Module):
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)

        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)

        output, _ = self.c_proj(attn_output)
        return output
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.attention import PagedAttention
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
@ -126,15 +127,17 @@ class YiAttention(nn.Module):
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
base=self.rope_theta,
|
||||
max_position=self.max_position_embeddings,
|
||||
rotary_dim=self.head_dim,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
rope_scaling=rope_scaling)
|
||||
max_position=max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -146,9 +149,10 @@ class YiAttention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||
input_metadata, cache_event)
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||