mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-16 04:55:45 +08:00
[Misc] Pass attention to impl backend (#12218)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
parent
5f0ec3935a
commit
86bfb6dba7
@ -1,8 +1,8 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
|
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
|
||||||
Tuple, Type, TypeVar)
|
Protocol, Set, Tuple, Type, TypeVar)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -223,6 +223,22 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionLayer(Protocol):
    """Structural type of the layer object passed to ``AttentionImpl.forward``.

    Any object that exposes the two scale attributes and a ``forward``
    method with the signature below satisfies this protocol; it does not
    need to inherit from it. Backend implementations read ``_k_scale`` and
    ``_v_scale`` from the layer they are given rather than taking them as
    separate ``k_scale`` / ``v_scale`` arguments.
    """

    # Key and value scale factors. Backends pass these through to their
    # KV-cache write and attention calls; several backends assert that both
    # equal 1.0 because they do not support scaling.
    _k_scale: float
    _v_scale: float

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Signature a conforming layer's ``forward`` must provide.

        This is a protocol stub and carries no implementation.
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
class AttentionImpl(ABC, Generic[T]):
|
class AttentionImpl(ABC, Generic[T]):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -244,13 +260,12 @@ class AttentionImpl(ABC, Generic[T]):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
layer: AttentionLayer,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: T,
|
attn_metadata: T,
|
||||||
k_scale: float = 1.0,
|
|
||||||
v_scale: float = 1.0,
|
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
|
AttentionLayer,
|
||||||
AttentionMetadata, AttentionType)
|
AttentionMetadata, AttentionType)
|
||||||
from vllm.attention.backends.utils import (CommonAttentionState,
|
from vllm.attention.backends.utils import (CommonAttentionState,
|
||||||
CommonMetadataBuilder)
|
CommonMetadataBuilder)
|
||||||
@ -358,13 +359,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
layer: AttentionLayer,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: BlocksparseFlashAttentionMetadata,
|
attn_metadata: BlocksparseFlashAttentionMetadata,
|
||||||
k_scale: float = 1.0,
|
|
||||||
v_scale: float = 1.0,
|
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention and PagedAttention.
|
"""Forward pass with FlashAttention and PagedAttention.
|
||||||
@ -401,8 +401,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
|||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping,
|
attn_metadata.slot_mapping,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
k_scale,
|
layer._k_scale,
|
||||||
v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
if prefill_meta := attn_metadata.prefill_metadata:
|
if prefill_meta := attn_metadata.prefill_metadata:
|
||||||
@ -439,8 +439,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
|||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
k_scale,
|
layer._k_scale,
|
||||||
v_scale,
|
layer._v_scale,
|
||||||
tp_rank=self.tp_rank,
|
tp_rank=self.tp_rank,
|
||||||
blocksparse_local_blocks=self.local_blocks,
|
blocksparse_local_blocks=self.local_blocks,
|
||||||
blocksparse_vert_stride=self.vert_stride,
|
blocksparse_vert_stride=self.vert_stride,
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import torch
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
|
AttentionLayer,
|
||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
AttentionMetadataBuilder,
|
AttentionMetadataBuilder,
|
||||||
AttentionType)
|
AttentionType)
|
||||||
@ -634,13 +635,12 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
layer: AttentionLayer,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: FlashAttentionMetadata,
|
attn_metadata: FlashAttentionMetadata,
|
||||||
k_scale: float = 1.0,
|
|
||||||
v_scale: float = 1.0,
|
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention.
|
"""Forward pass with FlashAttention.
|
||||||
@ -657,7 +657,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
NOTE: It in-place updates the output tensor.
|
NOTE: It in-place updates the output tensor.
|
||||||
"""
|
"""
|
||||||
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
|
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
|
||||||
assert k_scale == 1.0 and v_scale == 1.0, (
|
assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
|
||||||
"key/v_scale is not supported in FlashAttention.")
|
"key/v_scale is not supported in FlashAttention.")
|
||||||
|
|
||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
@ -709,8 +709,8 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
kv_cache[1],
|
kv_cache[1],
|
||||||
updated_slot_mapping.flatten(), # type: ignore[union-attr]
|
updated_slot_mapping.flatten(), # type: ignore[union-attr]
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
k_scale,
|
layer._k_scale,
|
||||||
v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
||||||
|
|||||||
@ -23,6 +23,7 @@ import torch
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
|
AttentionLayer,
|
||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
AttentionMetadataBuilder,
|
AttentionMetadataBuilder,
|
||||||
AttentionState, AttentionType)
|
AttentionState, AttentionType)
|
||||||
@ -792,13 +793,12 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
layer: AttentionLayer,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: FlashInferMetadata,
|
attn_metadata: FlashInferMetadata,
|
||||||
k_scale: float = 1.0,
|
|
||||||
v_scale: float = 1.0,
|
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
@ -826,8 +826,8 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
kv_cache[:, 1],
|
kv_cache[:, 1],
|
||||||
attn_metadata.slot_mapping.flatten(),
|
attn_metadata.slot_mapping.flatten(),
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
k_scale,
|
layer._k_scale,
|
||||||
v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
|
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
|
||||||
# to process the cache when the kv_cache_dtype is fp8
|
# to process the cache when the kv_cache_dtype is fp8
|
||||||
@ -886,8 +886,8 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
kv_cache,
|
kv_cache,
|
||||||
logits_soft_cap=logits_soft_cap,
|
logits_soft_cap=logits_soft_cap,
|
||||||
causal=True,
|
causal=True,
|
||||||
k_scale=k_scale,
|
k_scale=layer._k_scale,
|
||||||
v_scale=v_scale,
|
v_scale=layer._v_scale,
|
||||||
window_left=window_left)
|
window_left=window_left)
|
||||||
if decode_meta := attn_metadata.decode_metadata:
|
if decode_meta := attn_metadata.decode_metadata:
|
||||||
assert decode_meta is not None
|
assert decode_meta is not None
|
||||||
@ -897,8 +897,8 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
kv_cache,
|
kv_cache,
|
||||||
sm_scale=softmax_scale,
|
sm_scale=softmax_scale,
|
||||||
logits_soft_cap=logits_soft_cap,
|
logits_soft_cap=logits_soft_cap,
|
||||||
k_scale=k_scale,
|
k_scale=layer._k_scale,
|
||||||
v_scale=v_scale,
|
v_scale=layer._v_scale,
|
||||||
window_left=window_left)
|
window_left=window_left)
|
||||||
|
|
||||||
if prefill_output is None and decode_output is not None:
|
if prefill_output is None and decode_output is not None:
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import vllm_hpu_extension.ops as ops
|
|||||||
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
|
from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
|
AttentionLayer,
|
||||||
AttentionMetadata, AttentionType)
|
AttentionMetadata, AttentionType)
|
||||||
from vllm.attention.backends.utils import CommonAttentionState
|
from vllm.attention.backends.utils import CommonAttentionState
|
||||||
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
|
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
|
||||||
@ -152,13 +153,12 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
layer: AttentionLayer,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: HPUAttentionMetadata,
|
attn_metadata: HPUAttentionMetadata,
|
||||||
k_scale: float = 1.0,
|
|
||||||
v_scale: float = 1.0,
|
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with xFormers and PagedAttention.
|
"""Forward pass with xFormers and PagedAttention.
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import torch
|
|||||||
|
|
||||||
from vllm._ipex_ops import ipex_ops
|
from vllm._ipex_ops import ipex_ops
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
|
AttentionLayer,
|
||||||
AttentionMetadata, AttentionType)
|
AttentionMetadata, AttentionType)
|
||||||
from vllm.attention.backends.utils import CommonAttentionState
|
from vllm.attention.backends.utils import CommonAttentionState
|
||||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||||
@ -171,13 +172,12 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
layer: AttentionLayer,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: IpexAttnMetadata, # type: ignore
|
attn_metadata: IpexAttnMetadata, # type: ignore
|
||||||
k_scale: float = 1.0,
|
|
||||||
v_scale: float = 1.0,
|
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with IPEX varlen_attention and PagedAttention.
|
"""Forward pass with IPEX varlen_attention and PagedAttention.
|
||||||
@ -193,7 +193,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
|||||||
Returns:
|
Returns:
|
||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
assert k_scale == 1.0 and v_scale == 1.0
|
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
|
||||||
num_tokens, hidden_size = query.shape
|
num_tokens, hidden_size = query.shape
|
||||||
# Reshape the query, key, and value tensors.
|
# Reshape the query, key, and value tensors.
|
||||||
query = query.view(-1, self.num_heads, self.head_size)
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
@ -210,8 +210,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
|||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping.flatten(),
|
attn_metadata.slot_mapping.flatten(),
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
k_scale,
|
layer._k_scale,
|
||||||
v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
if attn_metadata.is_prompt:
|
if attn_metadata.is_prompt:
|
||||||
@ -296,8 +296,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
k_scale,
|
layer._k_scale,
|
||||||
v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Run PagedAttention V2.
|
# Run PagedAttention V2.
|
||||||
@ -329,8 +329,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
k_scale,
|
layer._k_scale,
|
||||||
v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reshape the output tensor.
|
# Reshape the output tensor.
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import torch
|
|||||||
import torch_xla.experimental.custom_kernel # Required to register custom ops.
|
import torch_xla.experimental.custom_kernel # Required to register custom ops.
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
|
AttentionLayer,
|
||||||
AttentionMetadata, AttentionType)
|
AttentionMetadata, AttentionType)
|
||||||
from vllm.attention.backends.utils import CommonAttentionState
|
from vllm.attention.backends.utils import CommonAttentionState
|
||||||
|
|
||||||
@ -150,13 +151,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
layer: AttentionLayer,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||||
attn_metadata: PallasMetadata,
|
attn_metadata: PallasMetadata,
|
||||||
k_scale: float = 1.0,
|
|
||||||
v_scale: float = 1.0,
|
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with Pallas attention.
|
"""Forward pass with Pallas attention.
|
||||||
@ -173,7 +173,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
Returns:
|
Returns:
|
||||||
shape = [batch_size, seq_len, num_heads * head_size]
|
shape = [batch_size, seq_len, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
assert k_scale == 1.0 and v_scale == 1.0
|
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
|
||||||
batch_size, seq_len, hidden_size = query.shape
|
batch_size, seq_len, hidden_size = query.shape
|
||||||
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
|
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
|
||||||
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
|
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import torch
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
|
AttentionLayer,
|
||||||
AttentionMetadata, AttentionType)
|
AttentionMetadata, AttentionType)
|
||||||
from vllm.attention.backends.utils import (CommonAttentionState,
|
from vllm.attention.backends.utils import (CommonAttentionState,
|
||||||
CommonMetadataBuilder)
|
CommonMetadataBuilder)
|
||||||
@ -414,13 +415,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
layer: AttentionLayer,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: ROCmFlashAttentionMetadata,
|
attn_metadata: ROCmFlashAttentionMetadata,
|
||||||
k_scale: float = 1.0,
|
|
||||||
v_scale: float = 1.0,
|
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention and PagedAttention.
|
"""Forward pass with FlashAttention and PagedAttention.
|
||||||
@ -458,8 +458,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping,
|
attn_metadata.slot_mapping,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
k_scale,
|
layer._k_scale,
|
||||||
v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||||
@ -567,8 +567,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
prefill_meta.max_query_len,
|
prefill_meta.max_query_len,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
self.sliding_window[0],
|
self.sliding_window[0],
|
||||||
k_scale,
|
layer._k_scale,
|
||||||
v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
if decode_meta := attn_metadata.decode_metadata:
|
if decode_meta := attn_metadata.decode_metadata:
|
||||||
@ -613,8 +613,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
k_scale,
|
layer._k_scale,
|
||||||
v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
output[num_prefill_tokens:] = PagedAttention.forward_decode(
|
output[num_prefill_tokens:] = PagedAttention.forward_decode(
|
||||||
@ -628,8 +628,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
k_scale,
|
layer._k_scale,
|
||||||
v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reshape the output tensor.
|
# Reshape the output tensor.
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import torch
|
|||||||
from torch.nn.functional import scaled_dot_product_attention
|
from torch.nn.functional import scaled_dot_product_attention
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
|
AttentionLayer,
|
||||||
AttentionMetadata,
|
AttentionMetadata,
|
||||||
AttentionMetadataBuilder,
|
AttentionMetadataBuilder,
|
||||||
AttentionType)
|
AttentionType)
|
||||||
@ -429,13 +430,12 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
layer: AttentionLayer,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: TorchSDPAMetadata, # type: ignore
|
attn_metadata: TorchSDPAMetadata, # type: ignore
|
||||||
k_scale: float = 1.0,
|
|
||||||
v_scale: float = 1.0,
|
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with torch SDPA and PagedAttention.
|
"""Forward pass with torch SDPA and PagedAttention.
|
||||||
@ -451,7 +451,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
Returns:
|
Returns:
|
||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
assert k_scale == 1.0 and v_scale == 1.0
|
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
|
||||||
attn_type = self.attn_type
|
attn_type = self.attn_type
|
||||||
if (attn_type == AttentionType.ENCODER
|
if (attn_type == AttentionType.ENCODER
|
||||||
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
|
||||||
@ -493,11 +493,9 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
# Update self-attention KV cache (prefill/decode)
|
# Update self-attention KV cache (prefill/decode)
|
||||||
updated_slot_mapping = attn_metadata.slot_mapping
|
updated_slot_mapping = attn_metadata.slot_mapping
|
||||||
|
|
||||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
PagedAttention.write_to_paged_cache(
|
||||||
value_cache,
|
key, value, key_cache, value_cache, updated_slot_mapping,
|
||||||
updated_slot_mapping,
|
self.kv_cache_dtype, layer._k_scale, layer._v_scale)
|
||||||
self.kv_cache_dtype,
|
|
||||||
k_scale, v_scale)
|
|
||||||
|
|
||||||
if attn_type != AttentionType.ENCODER:
|
if attn_type != AttentionType.ENCODER:
|
||||||
# Decoder self-attention supports chunked prefill.
|
# Decoder self-attention supports chunked prefill.
|
||||||
@ -571,8 +569,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
k_scale,
|
layer._k_scale,
|
||||||
v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reshape the output tensor.
|
# Reshape the output tensor.
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
|
|||||||
LowerTriangularMaskWithTensorBias)
|
LowerTriangularMaskWithTensorBias)
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
|
AttentionLayer,
|
||||||
AttentionMetadata, AttentionType)
|
AttentionMetadata, AttentionType)
|
||||||
from vllm.attention.backends.utils import (
|
from vllm.attention.backends.utils import (
|
||||||
CommonAttentionState, CommonMetadataBuilder,
|
CommonAttentionState, CommonMetadataBuilder,
|
||||||
@ -412,13 +413,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
layer: AttentionLayer,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: Optional[torch.Tensor],
|
key: Optional[torch.Tensor],
|
||||||
value: Optional[torch.Tensor],
|
value: Optional[torch.Tensor],
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: "XFormersMetadata",
|
attn_metadata: "XFormersMetadata",
|
||||||
k_scale: float = 1.0,
|
|
||||||
v_scale: float = 1.0,
|
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with xFormers and PagedAttention.
|
"""Forward pass with xFormers and PagedAttention.
|
||||||
@ -524,11 +524,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
# If kv_cache is not provided, the new key and value tensors are
|
# If kv_cache is not provided, the new key and value tensors are
|
||||||
# not cached. This happens during the initial memory
|
# not cached. This happens during the initial memory
|
||||||
# profiling run.
|
# profiling run.
|
||||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
PagedAttention.write_to_paged_cache(
|
||||||
value_cache,
|
key, value, key_cache, value_cache, updated_slot_mapping,
|
||||||
updated_slot_mapping,
|
self.kv_cache_dtype, layer._k_scale, layer._v_scale)
|
||||||
self.kv_cache_dtype,
|
|
||||||
k_scale, v_scale)
|
|
||||||
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
(num_prefill_query_tokens, num_prefill_kv_tokens,
|
||||||
num_decode_query_tokens) = \
|
num_decode_query_tokens) = \
|
||||||
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
|
get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
|
||||||
@ -580,8 +578,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
prefill_meta.max_query_len,
|
prefill_meta.max_query_len,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
self.sliding_window,
|
self.sliding_window,
|
||||||
k_scale,
|
layer._k_scale,
|
||||||
v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
assert output[:num_prefill_query_tokens].shape == out.shape
|
assert output[:num_prefill_query_tokens].shape == out.shape
|
||||||
output[:num_prefill_query_tokens] = out
|
output[:num_prefill_query_tokens] = out
|
||||||
@ -607,8 +605,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
k_scale,
|
layer._k_scale,
|
||||||
v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reshape the output tensor.
|
# Reshape the output tensor.
|
||||||
|
|||||||
@ -243,8 +243,7 @@ def unified_attention(
|
|||||||
attn_metadata = forward_context.attn_metadata
|
attn_metadata = forward_context.attn_metadata
|
||||||
self = forward_context.attn_layers[layer_name]
|
self = forward_context.attn_layers[layer_name]
|
||||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||||
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
|
return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
|
||||||
self._k_scale, self._v_scale)
|
|
||||||
|
|
||||||
|
|
||||||
def unified_attention_fake(
|
def unified_attention_fake(
|
||||||
@ -276,13 +275,12 @@ def unified_attention_with_output(
|
|||||||
attn_metadata = forward_context.attn_metadata
|
attn_metadata = forward_context.attn_metadata
|
||||||
self = forward_context.attn_layers[layer_name]
|
self = forward_context.attn_layers[layer_name]
|
||||||
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||||
self.impl.forward(query,
|
self.impl.forward(self,
|
||||||
|
query,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
self._k_scale,
|
|
||||||
self._v_scale,
|
|
||||||
output=output)
|
output=output)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -130,13 +130,12 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: FlashAttentionMetadata,
|
attn_metadata: FlashAttentionMetadata,
|
||||||
k_scale: float = 1.0,
|
|
||||||
v_scale: float = 1.0,
|
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention.
|
"""Forward pass with FlashAttention.
|
||||||
@ -151,7 +150,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
|
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
|
||||||
assert k_scale == 1.0 and v_scale == 1.0, (
|
assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
|
||||||
"key/v_scale is not supported in FlashAttention.")
|
"key/v_scale is not supported in FlashAttention.")
|
||||||
|
|
||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
@ -183,8 +182,8 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping,
|
attn_metadata.slot_mapping,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
k_scale,
|
layer._k_scale,
|
||||||
v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute attention and update output up to `num_actual_tokens`.
|
# Compute attention and update output up to `num_actual_tokens`.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user