[Misc] Pass attention to impl backend (#12218)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
wangxiyuan 2025-01-20 23:25:28 +08:00 committed by GitHub
parent 5f0ec3935a
commit 86bfb6dba7
12 changed files with 86 additions and 78 deletions

View File

@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
Tuple, Type, TypeVar)
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
Protocol, Set, Tuple, Type, TypeVar)
import torch
@@ -223,6 +223,22 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
raise NotImplementedError
class AttentionLayer(Protocol):
_k_scale: float
_v_scale: float
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
...
class AttentionImpl(ABC, Generic[T]):
@abstractmethod
@@ -244,13 +260,12 @@ class AttentionImpl(ABC, Generic[T]):
@abstractmethod
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: T,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
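
The net effect of this interface change: `AttentionImpl.forward` now receives the attention layer itself as its first argument, and the per-layer FP8 KV-cache scales travel on that object (`layer._k_scale`, `layer._v_scale`) instead of being passed as `k_scale`/`v_scale` keyword arguments. Below is a minimal sketch of a backend written against the new signature; `ToyAttentionImpl` and its no-op body are illustrative only and not part of this commit.

```python
# Sketch only: a hypothetical backend adopting the new AttentionImpl.forward
# signature introduced by this commit.
from typing import Optional

import torch

from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
                                              AttentionMetadata)


class ToyAttentionImpl(AttentionImpl):
    """Illustrative subclass; real backends also implement __init__, etc."""

    def forward(
        self,
        layer: AttentionLayer,      # new first argument
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # The scales are now read from the layer instead of being passed in.
        k_scale, v_scale = layer._k_scale, layer._v_scale
        # A real backend would quantize/write key and value into kv_cache
        # using these scales and then run its attention kernel; this sketch
        # just returns the query unchanged.
        del k_scale, v_scale, key, value, kv_cache, attn_metadata, output
        return query
```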

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
@@ -358,13 +359,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: BlocksparseFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@@ -401,8 +401,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
if prefill_meta := attn_metadata.prefill_metadata:
@@ -439,8 +439,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
self.num_kv_heads,
self.scale,
self.alibi_slopes,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
tp_rank=self.tp_rank,
blocksparse_local_blocks=self.local_blocks,
blocksparse_vert_stride=self.vert_stride,

View File

@@ -8,6 +8,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
@@ -634,13 +635,12 @@ class FlashAttentionImpl(AttentionImpl):
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@@ -657,7 +657,7 @@ class FlashAttentionImpl(AttentionImpl):
NOTE: It in-place updates the output tensor.
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert k_scale == 1.0 and v_scale == 1.0, (
assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")
assert output is not None, "Output tensor must be provided."
@@ -709,8 +709,8 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache[1],
updated_slot_mapping.flatten(), # type: ignore[union-attr]
kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
(num_prefill_query_tokens, num_prefill_kv_tokens,
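
In the FlashAttention backend the FP8 guard now reads the scales off the layer object (`assert layer._k_scale == 1.0 and layer._v_scale == 1.0`). Because `AttentionLayer` is a `typing.Protocol`, any object exposing those attributes satisfies it structurally, which is convenient when exercising an impl outside the full model, e.g. in a test. A minimal sketch, assuming a hypothetical `FakeLayer`; real calls go through `vllm.attention.layer.Attention`.

```python
# Sketch only: a duck-typed stand-in exposing the attributes the backends
# read from the AttentionLayer protocol (FakeLayer is hypothetical).
from dataclasses import dataclass


@dataclass
class FakeLayer:
    _k_scale: float = 1.0  # FlashAttention asserts these stay at 1.0
    _v_scale: float = 1.0  # (FP8 KV cache is not supported there)


def check_fp8_unsupported(layer) -> None:
    # Same guard the FlashAttention backend now applies via the layer object.
    assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
        "key/v_scale is not supported in FlashAttention.")


check_fp8_unsupported(FakeLayer())  # passes silently
```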

View File

@@ -23,6 +23,7 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, AttentionType)
@@ -792,13 +793,12 @@ class FlashInferImpl(AttentionImpl):
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashInferMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@@ -826,8 +826,8 @@ class FlashInferImpl(AttentionImpl):
kv_cache[:, 1],
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
@@ -886,8 +886,8 @@ class FlashInferImpl(AttentionImpl):
kv_cache,
logits_soft_cap=logits_soft_cap,
causal=True,
k_scale=k_scale,
v_scale=v_scale,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
window_left=window_left)
if decode_meta := attn_metadata.decode_metadata:
assert decode_meta is not None
@@ -897,8 +897,8 @@ class FlashInferImpl(AttentionImpl):
kv_cache,
sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
window_left=window_left)
if prefill_output is None and decode_output is not None:

View File

@@ -11,6 +11,7 @@ import vllm_hpu_extension.ops as ops
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
@@ -152,13 +153,12 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: HPUAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.

View File

@@ -7,6 +7,7 @@ import torch
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import (PagedAttention,
@@ -171,13 +172,12 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: IpexAttnMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
@@ -193,7 +193,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
@@ -210,8 +210,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
value_cache,
attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
if attn_metadata.is_prompt:
@@ -296,8 +296,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
else:
# Run PagedAttention V2.
@@ -329,8 +329,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
# Reshape the output tensor.

View File

@@ -5,6 +5,7 @@ import torch
import torch_xla.experimental.custom_kernel # Required to register custom ops.
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
@@ -150,13 +151,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
attn_metadata: PallasMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
@@ -173,7 +173,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)

View File

@@ -7,6 +7,7 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
@@ -414,13 +415,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@@ -458,8 +458,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
num_prefill_tokens = attn_metadata.num_prefill_tokens
@@ -567,8 +567,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
if decode_meta := attn_metadata.decode_metadata:
@@ -613,8 +613,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
else:
output[num_prefill_tokens:] = PagedAttention.forward_decode(
@@ -628,8 +628,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self.num_kv_heads,
self.scale,
self.alibi_slopes,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
# Reshape the output tensor.

View File

@@ -7,6 +7,7 @@ import torch
from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
@@ -429,13 +430,12 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: TorchSDPAMetadata, # type: ignore
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@@ -451,7 +451,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert k_scale == 1.0 and v_scale == 1.0
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
@@ -493,11 +493,9 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping = attn_metadata.slot_mapping
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype,
k_scale, v_scale)
PagedAttention.write_to_paged_cache(
key, value, key_cache, value_cache, updated_slot_mapping,
self.kv_cache_dtype, layer._k_scale, layer._v_scale)
if attn_type != AttentionType.ENCODER:
# Decoder self-attention supports chunked prefill.
@@ -571,8 +569,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
self.num_kv_heads,
self.scale,
self.alibi_slopes,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
# Reshape the output tensor.

View File

@@ -10,6 +10,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
LowerTriangularMaskWithTensorBias)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (
CommonAttentionState, CommonMetadataBuilder,
@@ -412,13 +413,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: Optional[torch.Tensor],
value: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: "XFormersMetadata",
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@@ -524,11 +524,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
# profiling run.
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache,
updated_slot_mapping,
self.kv_cache_dtype,
k_scale, v_scale)
PagedAttention.write_to_paged_cache(
key, value, key_cache, value_cache, updated_slot_mapping,
self.kv_cache_dtype, layer._k_scale, layer._v_scale)
(num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens) = \
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
@@ -580,8 +578,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
assert output[:num_prefill_query_tokens].shape == out.shape
output[:num_prefill_query_tokens] = out
@@ -607,8 +605,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
self.num_kv_heads,
self.scale,
self.alibi_slopes,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
# Reshape the output tensor.

View File

@@ -243,8 +243,7 @@ def unified_attention(
attn_metadata = forward_context.attn_metadata
self = forward_context.attn_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
self._k_scale, self._v_scale)
return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
def unified_attention_fake(
@@ -276,13 +275,12 @@ def unified_attention_with_output(
attn_metadata = forward_context.attn_metadata
self = forward_context.attn_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(query,
self.impl.forward(self,
query,
key,
value,
kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
output=output)
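
At the call site, `unified_attention` and `unified_attention_with_output` now hand the resolved `Attention` layer (`self`) to `impl.forward` instead of unpacking `self._k_scale` and `self._v_scale` themselves. A minimal sketch of the resulting call shape (the helper name is hypothetical; the real code looks `layer` up from the forward context as shown above):

```python
# Sketch only: how a caller now drives an AttentionImpl after this commit.
def run_attention(layer, impl, query, key, value, kv_cache, attn_metadata,
                  output=None):
    # The scales no longer appear at the call site; the impl reads
    # layer._k_scale / layer._v_scale itself.
    return impl.forward(layer, query, key, value, kv_cache, attn_metadata,
                        output=output)
```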

View File

@@ -130,13 +130,12 @@ class FlashAttentionImpl(AttentionImpl):
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
k_scale: float = 1.0,
v_scale: float = 1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@@ -151,7 +150,7 @@ class FlashAttentionImpl(AttentionImpl):
shape = [num_tokens, num_heads * head_size]
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert k_scale == 1.0 and v_scale == 1.0, (
assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")
assert output is not None, "Output tensor must be provided."
@@ -183,8 +182,8 @@ class FlashAttentionImpl(AttentionImpl):
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
k_scale,
v_scale,
layer._k_scale,
layer._v_scale,
)
# Compute attention and update output up to `num_actual_tokens`.
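
Note that the V1 flash-attention backend annotates the new parameter as `torch.nn.Module` rather than `AttentionLayer`; the real `Attention` layer is an `nn.Module` that carries the scale attributes, and the backend only reads them. A minimal sketch of a module that could stand in for it in isolation (the class name and usage are illustrative, not part of this commit):

```python
# Sketch only: a minimal nn.Module exposing the attributes the V1 backend
# reads via the duck-typed `layer` argument.
import torch


class ScaleCarrier(torch.nn.Module):

    def __init__(self, k_scale: float = 1.0, v_scale: float = 1.0):
        super().__init__()
        # Plain attributes are enough; the backend only reads them.
        self._k_scale = k_scale
        self._v_scale = v_scale


layer = ScaleCarrier()
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
```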