[Misc] Pass attention to impl backend (#12218)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan, 2025-01-20 23:25:28 +08:00, committed by GitHub
Parent: 5f0ec3935a
Commit: 86bfb6dba7
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
12 changed files with 86 additions and 78 deletions
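
In short, `AttentionImpl.forward` now receives the owning attention layer and reads the FP8 KV-cache scales as `layer._k_scale` / `layer._v_scale`, instead of taking `k_scale` and `v_scale` floats as arguments; a new `AttentionLayer` protocol describes what the impls expect from that object. A runnable sketch of the resulting calling convention follows the first file's diff below.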

View File

@@ -1,8 +1,8 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from dataclasses import dataclass, fields
-from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
-                    Tuple, Type, TypeVar)
+from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
+                    Protocol, Set, Tuple, Type, TypeVar)
 
 import torch
@@ -223,6 +223,22 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
         raise NotImplementedError
 
 
+class AttentionLayer(Protocol):
+
+    _k_scale: float
+    _v_scale: float
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        ...
+
+
 class AttentionImpl(ABC, Generic[T]):
 
     @abstractmethod
@@ -244,13 +260,12 @@ class AttentionImpl(ABC, Generic[T]):
     @abstractmethod
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: T,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
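
To make the new interface concrete, here is a minimal, self-contained sketch of the calling convention it establishes. This is not vLLM code: `ToyAttentionImpl` and `ToyLayer` are illustrative stand-ins, and the protocol body is trimmed to the two scale attributes the backends below actually read.

from typing import Optional, Protocol

import torch


class AttentionLayer(Protocol):
    # Structural type: any object carrying per-layer KV-cache scales
    # (the real protocol above also declares forward()).
    _k_scale: float
    _v_scale: float


class ToyAttentionImpl:
    """Illustrative backend following the new AttentionImpl.forward shape."""

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: dict,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Scales travel with the layer instead of being passed as floats.
        # This toy backend has no FP8 KV cache, so it insists on unit scales,
        # mirroring the FlashAttention backends in this commit.
        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
        scale = query.shape[-1] ** -0.5
        scores = (query @ key.transpose(-1, -2)) * scale
        return torch.softmax(scores, dim=-1) @ value


class ToyLayer(torch.nn.Module):
    """Owns the scales and hands itself to its impl, as in the diffs below."""

    def __init__(self) -> None:
        super().__init__()
        self._k_scale = 1.0
        self._v_scale = 1.0
        self.impl = ToyAttentionImpl()

    def forward(self, q, k, v, kv_cache, attn_metadata):
        return self.impl.forward(self, q, k, v, kv_cache, attn_metadata)


if __name__ == "__main__":
    layer = ToyLayer()
    q = k = v = torch.randn(2, 4, 8)
    print(layer(q, k, v, torch.empty(0), {}).shape)  # torch.Size([2, 4, 8])

Because each backend now sees the whole layer, later additions to per-layer KV-cache state only require extending the protocol and the backends that consume it, rather than widening every forward() signature and call site again.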

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
 import torch
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (CommonAttentionState,
                                            CommonMetadataBuilder)
@@ -358,13 +359,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: BlocksparseFlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -401,8 +401,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
                 value_cache,
                 attn_metadata.slot_mapping,
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         if prefill_meta := attn_metadata.prefill_metadata:
@@ -439,8 +439,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
             self.num_kv_heads,
             self.scale,
             self.alibi_slopes,
-            k_scale,
-            v_scale,
+            layer._k_scale,
+            layer._v_scale,
             tp_rank=self.tp_rank,
             blocksparse_local_blocks=self.local_blocks,
             blocksparse_vert_stride=self.vert_stride,

View File

@@ -8,6 +8,7 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionType)
@@ -634,13 +635,12 @@ class FlashAttentionImpl(AttentionImpl):
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -657,7 +657,7 @@ class FlashAttentionImpl(AttentionImpl):
         NOTE: It in-place updates the output tensor.
         """
         # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
-        assert k_scale == 1.0 and v_scale == 1.0, (
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
 
         assert output is not None, "Output tensor must be provided."
@@ -709,8 +709,8 @@ class FlashAttentionImpl(AttentionImpl):
                 kv_cache[1],
                 updated_slot_mapping.flatten(),  # type: ignore[union-attr]
                 kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         (num_prefill_query_tokens, num_prefill_kv_tokens,

View File

@@ -23,6 +23,7 @@ import torch
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionState, AttentionType)
@@ -792,13 +793,12 @@ class FlashInferImpl(AttentionImpl):
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -826,8 +826,8 @@ class FlashInferImpl(AttentionImpl):
                 kv_cache[:, 1],
                 attn_metadata.slot_mapping.flatten(),
                 kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
             # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
             # to process the cache when the kv_cache_dtype is fp8
@@ -886,8 +886,8 @@ class FlashInferImpl(AttentionImpl):
                 kv_cache,
                 logits_soft_cap=logits_soft_cap,
                 causal=True,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
                 window_left=window_left)
         if decode_meta := attn_metadata.decode_metadata:
             assert decode_meta is not None
@@ -897,8 +897,8 @@ class FlashInferImpl(AttentionImpl):
                 kv_cache,
                 sm_scale=softmax_scale,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
                 window_left=window_left)
 
         if prefill_output is None and decode_output is not None:

View File

@@ -11,6 +11,7 @@ import vllm_hpu_extension.ops as ops
 from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
@@ -152,13 +153,12 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: HPUAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.

View File

@@ -7,6 +7,7 @@ import torch
 
 from vllm._ipex_ops import ipex_ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.paged_attn import (PagedAttention,
@@ -171,13 +172,12 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: IpexAttnMetadata,  # type: ignore
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.
@@ -193,7 +193,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
@@ -210,8 +210,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
                 value_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         if attn_metadata.is_prompt:
@@ -296,8 +296,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
             else:
                 # Run PagedAttention V2.
@@ -329,8 +329,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         # Reshape the output tensor.

View File

@@ -5,6 +5,7 @@ import torch
 import torch_xla.experimental.custom_kernel  # Required to register custom ops.
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
@@ -150,13 +151,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Tuple[torch.Tensor, torch.Tensor],
         attn_metadata: PallasMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.
@@ -173,7 +173,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         Returns:
             shape = [batch_size, seq_len, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         batch_size, seq_len, hidden_size = query.shape
         query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
         key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)

View File

@@ -7,6 +7,7 @@ import torch
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (CommonAttentionState,
                                            CommonMetadataBuilder)
@@ -414,13 +415,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: ROCmFlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -458,8 +458,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 value_cache,
                 attn_metadata.slot_mapping,
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens
@@ -567,8 +567,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     prefill_meta.max_query_len,
                     self.alibi_slopes,
                     self.sliding_window[0],
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
@@ -613,8 +613,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
             else:
                 output[num_prefill_tokens:] = PagedAttention.forward_decode(
@@ -628,8 +628,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     self.num_kv_heads,
                     self.scale,
                     self.alibi_slopes,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         # Reshape the output tensor.

View File

@@ -7,6 +7,7 @@ import torch
 from torch.nn.functional import scaled_dot_product_attention
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionType)
@@ -429,13 +430,12 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: TorchSDPAMetadata,  # type: ignore
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -451,7 +451,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         attn_type = self.attn_type
         if (attn_type == AttentionType.ENCODER
                 and (not attn_metadata.is_all_encoder_attn_metadata_set)):
@@ -493,11 +493,9 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
                 # Update self-attention KV cache (prefill/decode)
                 updated_slot_mapping = attn_metadata.slot_mapping
-                PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                    value_cache,
-                                                    updated_slot_mapping,
-                                                    self.kv_cache_dtype,
-                                                    k_scale, v_scale)
+                PagedAttention.write_to_paged_cache(
+                    key, value, key_cache, value_cache, updated_slot_mapping,
+                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)
 
         if attn_type != AttentionType.ENCODER:
             # Decoder self-attention supports chunked prefill.
@@ -571,8 +569,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         # Reshape the output tensor.

View File

@@ -10,6 +10,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
                                          LowerTriangularMaskWithTensorBias)
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (
     CommonAttentionState, CommonMetadataBuilder,
@@ -412,13 +413,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: Optional[torch.Tensor],
         value: Optional[torch.Tensor],
         kv_cache: torch.Tensor,
         attn_metadata: "XFormersMetadata",
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
@@ -524,11 +524,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
             # If kv_cache is not provided, the new key and value tensors are
            # not cached. This happens during the initial memory
            # profiling run.
-            PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                value_cache,
-                                                updated_slot_mapping,
-                                                self.kv_cache_dtype,
-                                                k_scale, v_scale)
+            PagedAttention.write_to_paged_cache(
+                key, value, key_cache, value_cache, updated_slot_mapping,
+                self.kv_cache_dtype, layer._k_scale, layer._v_scale)
 
         (num_prefill_query_tokens, num_prefill_kv_tokens,
         num_decode_query_tokens) = \
             get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
@@ -580,8 +578,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                 prefill_meta.max_query_len,
                 self.alibi_slopes,
                 self.sliding_window,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
             assert output[:num_prefill_query_tokens].shape == out.shape
             output[:num_prefill_query_tokens] = out
@@ -607,8 +605,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         # Reshape the output tensor.

View File

@@ -243,8 +243,7 @@ def unified_attention(
     attn_metadata = forward_context.attn_metadata
     self = forward_context.attn_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
-    return self.impl.forward(query, key, value, kv_cache, attn_metadata,
-                             self._k_scale, self._v_scale)
+    return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
 
 
 def unified_attention_fake(
@@ -276,13 +275,12 @@ def unified_attention_with_output(
     attn_metadata = forward_context.attn_metadata
     self = forward_context.attn_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
-    self.impl.forward(query,
+    self.impl.forward(self,
+                      query,
                       key,
                       value,
                       kv_cache,
                       attn_metadata,
-                      self._k_scale,
-                      self._v_scale,
                       output=output)
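
For comparison with the hunks above, here is a rough sketch of the dispatch shape in `unified_attention` after this change. The names are made up for illustration (`DummyLayer`, `DummyImpl`, and the `attn_layers` registry contents stand in for the real forward-context plumbing):

import torch


class DummyImpl:
    def forward(self, layer, query, key, value, kv_cache, attn_metadata,
                output=None):
        # The backend pulls the per-layer scales off the layer it was handed.
        return query * layer._k_scale * layer._v_scale


class DummyLayer:
    def __init__(self) -> None:
        self._k_scale = 1.0
        self._v_scale = 1.0
        self.kv_cache = [torch.empty(0)]
        self.impl = DummyImpl()


# Stand-in for forward_context.attn_layers.
attn_layers = {"model.layers.0.self_attn.attn": DummyLayer()}


def unified_attention_sketch(query, key, value, layer_name):
    self = attn_layers[layer_name]
    kv_cache = self.kv_cache[0]
    # The only change from the old call is the leading `self` argument and the
    # dropped k_scale / v_scale floats.
    return self.impl.forward(self, query, key, value, kv_cache, None)


out = unified_attention_sketch(torch.ones(2, 8), None, None,
                               "model.layers.0.self_attn.attn")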

View File

@@ -130,13 +130,12 @@ class FlashAttentionImpl(AttentionImpl):
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -151,7 +150,7 @@ class FlashAttentionImpl(AttentionImpl):
             shape = [num_tokens, num_heads * head_size]
         """
         # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
-        assert k_scale == 1.0 and v_scale == 1.0, (
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
 
         assert output is not None, "Output tensor must be provided."
@@ -183,8 +182,8 @@ class FlashAttentionImpl(AttentionImpl):
             value_cache,
             attn_metadata.slot_mapping,
             self.kv_cache_dtype,
-            k_scale,
-            v_scale,
+            layer._k_scale,
+            layer._v_scale,
         )
 
         # Compute attention and update output up to `num_actual_tokens`.