[Misc] Pass attention to impl backend (#12218)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
commit 86bfb6dba7
parent 5f0ec3935a
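In short: the attention backends' `forward()` implementations now take the attention layer object itself as their first argument and read the per-layer `_k_scale`/`_v_scale` values from it, instead of having the caller thread `k_scale`/`v_scale` floats through every signature. A new `AttentionLayer` Protocol in `vllm.attention.backends.abstract` spells out what the impls expect from that object, and the `unified_attention*` call sites are updated to pass `self` through. The snippet below is an illustrative sketch only, not code from this commit; it just contrasts the two signatures:

    # Sketch of the signature change; names follow the diff below.

    # Before: the caller passed the scales as plain floats.
    def forward(self, query, key, value, kv_cache, attn_metadata,
                k_scale=1.0, v_scale=1.0, output=None): ...

    # After: the attention layer is passed in, and the backend reads
    # layer._k_scale / layer._v_scale from it.
    def forward(self, layer, query, key, value, kv_cache, attn_metadata,
                output=None): ...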
@@ -1,8 +1,8 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from dataclasses import dataclass, fields
-from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
-                    Tuple, Type, TypeVar)
+from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
+                    Protocol, Set, Tuple, Type, TypeVar)
 
 import torch
 
@@ -223,6 +223,22 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
         raise NotImplementedError
 
 
+class AttentionLayer(Protocol):
+
+    _k_scale: float
+    _v_scale: float
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        ...
+
+
 class AttentionImpl(ABC, Generic[T]):
 
     @abstractmethod
@@ -244,13 +260,12 @@ class AttentionImpl(ABC, Generic[T]):
     @abstractmethod
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: T,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
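To make the new contract concrete, here is a minimal, self-contained sketch of a backend that follows the `forward(layer, ...)` convention defined above. The `ToyLayer` and `ToyAttentionImpl` names are hypothetical and not part of this commit; the layer object only needs to expose the scale attributes the `AttentionLayer` protocol declares:

    # Hedged sketch, not vLLM code: a toy impl following the new convention.
    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class ToyLayer:
        # Carries just the attributes the AttentionLayer protocol adds here.
        _k_scale: float = 1.0
        _v_scale: float = 1.0


    class ToyAttentionImpl:

        def forward(
            self,
            layer: ToyLayer,
            query: torch.Tensor,  # [num_tokens, num_heads, head_size]
            key: torch.Tensor,
            value: torch.Tensor,
            kv_cache: torch.Tensor,  # unused in this sketch
            attn_metadata: Optional[object] = None,  # unused in this sketch
            output: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            # Scales are read from the layer instead of arriving as
            # k_scale/v_scale arguments.
            key = key * layer._k_scale
            value = value * layer._v_scale
            scale = query.shape[-1]**-0.5
            scores = torch.einsum("qhd,khd->hqk", query, key) * scale
            out = torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), value)
            if output is not None:
                output.copy_(out)
                return output
            return out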
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
 import torch
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (CommonAttentionState,
                                            CommonMetadataBuilder)
@@ -358,13 +359,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: BlocksparseFlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -401,8 +401,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
                 value_cache,
                 attn_metadata.slot_mapping,
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         if prefill_meta := attn_metadata.prefill_metadata:
@@ -439,8 +439,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
                 tp_rank=self.tp_rank,
                 blocksparse_local_blocks=self.local_blocks,
                 blocksparse_vert_stride=self.vert_stride,
@@ -8,6 +8,7 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionType)
@@ -634,13 +635,12 @@ class FlashAttentionImpl(AttentionImpl):
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -657,7 +657,7 @@ class FlashAttentionImpl(AttentionImpl):
         NOTE: It in-place updates the output tensor.
         """
         # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
-        assert k_scale == 1.0 and v_scale == 1.0, (
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
 
         assert output is not None, "Output tensor must be provided."
@@ -709,8 +709,8 @@ class FlashAttentionImpl(AttentionImpl):
                 kv_cache[1],
                 updated_slot_mapping.flatten(),  # type: ignore[union-attr]
                 kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         (num_prefill_query_tokens, num_prefill_kv_tokens,
@@ -23,6 +23,7 @@ import torch
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionState, AttentionType)
@@ -792,13 +793,12 @@ class FlashInferImpl(AttentionImpl):
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashInferMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
@@ -826,8 +826,8 @@ class FlashInferImpl(AttentionImpl):
                 kv_cache[:, 1],
                 attn_metadata.slot_mapping.flatten(),
                 kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
             # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
             # to process the cache when the kv_cache_dtype is fp8
@@ -886,8 +886,8 @@ class FlashInferImpl(AttentionImpl):
                     kv_cache,
                     logits_soft_cap=logits_soft_cap,
                     causal=True,
-                    k_scale=k_scale,
-                    v_scale=v_scale,
+                    k_scale=layer._k_scale,
+                    v_scale=layer._v_scale,
                     window_left=window_left)
         if decode_meta := attn_metadata.decode_metadata:
             assert decode_meta is not None
@@ -897,8 +897,8 @@ class FlashInferImpl(AttentionImpl):
                 kv_cache,
                 sm_scale=softmax_scale,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=k_scale,
-                v_scale=v_scale,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
                 window_left=window_left)
 
         if prefill_output is None and decode_output is not None:
@@ -11,6 +11,7 @@ import vllm_hpu_extension.ops as ops
 from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
@@ -152,13 +153,12 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: HPUAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
@@ -7,6 +7,7 @@ import torch
 
 from vllm._ipex_ops import ipex_ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.paged_attn import (PagedAttention,
@@ -171,13 +172,12 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: IpexAttnMetadata,  # type: ignore
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.
@@ -193,7 +193,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
@@ -210,8 +210,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
                 value_cache,
                 attn_metadata.slot_mapping.flatten(),
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         if attn_metadata.is_prompt:
@@ -296,8 +296,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
             else:
                 # Run PagedAttention V2.
@@ -329,8 +329,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         # Reshape the output tensor.
@@ -5,6 +5,7 @@ import torch
 import torch_xla.experimental.custom_kernel  # Required to register custom ops.
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 
@@ -150,13 +151,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Tuple[torch.Tensor, torch.Tensor],
         attn_metadata: PallasMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.
@@ -173,7 +173,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         Returns:
             shape = [batch_size, seq_len, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         batch_size, seq_len, hidden_size = query.shape
         query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
         key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
@@ -7,6 +7,7 @@ import torch
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (CommonAttentionState,
                                            CommonMetadataBuilder)
@@ -414,13 +415,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: ROCmFlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -458,8 +458,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 value_cache,
                 attn_metadata.slot_mapping,
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens
@@ -567,8 +567,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     prefill_meta.max_query_len,
                     self.alibi_slopes,
                     self.sliding_window[0],
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         if decode_meta := attn_metadata.decode_metadata:
@@ -613,8 +613,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     max_seq_len,
                     self.alibi_slopes,
                     self.kv_cache_dtype,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
             else:
                 output[num_prefill_tokens:] = PagedAttention.forward_decode(
@@ -628,8 +628,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     self.num_kv_heads,
                     self.scale,
                     self.alibi_slopes,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
 
         # Reshape the output tensor.
@@ -7,6 +7,7 @@ import torch
 from torch.nn.functional import scaled_dot_product_attention
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionType)
@@ -429,13 +430,12 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
 
     def forward(
         self,
+        layer: AttentionLayer,
        query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: TorchSDPAMetadata,  # type: ignore
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -451,7 +451,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        assert k_scale == 1.0 and v_scale == 1.0
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0
         attn_type = self.attn_type
         if (attn_type == AttentionType.ENCODER
                 and (not attn_metadata.is_all_encoder_attn_metadata_set)):
@@ -493,11 +493,9 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
                 # Update self-attention KV cache (prefill/decode)
                 updated_slot_mapping = attn_metadata.slot_mapping
 
-            PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                value_cache,
-                                                updated_slot_mapping,
-                                                self.kv_cache_dtype,
-                                                k_scale, v_scale)
+            PagedAttention.write_to_paged_cache(
+                key, value, key_cache, value_cache, updated_slot_mapping,
+                self.kv_cache_dtype, layer._k_scale, layer._v_scale)
 
         if attn_type != AttentionType.ENCODER:
             # Decoder self-attention supports chunked prefill.
@@ -571,8 +569,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         # Reshape the output tensor.
@@ -10,6 +10,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
                                          LowerTriangularMaskWithTensorBias)
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import (
     CommonAttentionState, CommonMetadataBuilder,
@@ -412,13 +413,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
 
     def forward(
         self,
+        layer: AttentionLayer,
         query: torch.Tensor,
         key: Optional[torch.Tensor],
         value: Optional[torch.Tensor],
         kv_cache: torch.Tensor,
         attn_metadata: "XFormersMetadata",
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
@@ -524,11 +524,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                 # If kv_cache is not provided, the new key and value tensors are
                 # not cached. This happens during the initial memory
                 # profiling run.
-                PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                    value_cache,
-                                                    updated_slot_mapping,
-                                                    self.kv_cache_dtype,
-                                                    k_scale, v_scale)
+                PagedAttention.write_to_paged_cache(
+                    key, value, key_cache, value_cache, updated_slot_mapping,
+                    self.kv_cache_dtype, layer._k_scale, layer._v_scale)
         (num_prefill_query_tokens, num_prefill_kv_tokens,
         num_decode_query_tokens) = \
             get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
@@ -580,8 +578,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                     prefill_meta.max_query_len,
                     self.alibi_slopes,
                     self.sliding_window,
-                    k_scale,
-                    v_scale,
+                    layer._k_scale,
+                    layer._v_scale,
                 )
                 assert output[:num_prefill_query_tokens].shape == out.shape
                 output[:num_prefill_query_tokens] = out
@@ -607,8 +605,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
 
         # Reshape the output tensor.
@@ -243,8 +243,7 @@ def unified_attention(
     attn_metadata = forward_context.attn_metadata
     self = forward_context.attn_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
-    return self.impl.forward(query, key, value, kv_cache, attn_metadata,
-                             self._k_scale, self._v_scale)
+    return self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
 
 
 def unified_attention_fake(
@@ -276,13 +275,12 @@ def unified_attention_with_output(
     attn_metadata = forward_context.attn_metadata
     self = forward_context.attn_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
-    self.impl.forward(query,
+    self.impl.forward(self,
+                      query,
                       key,
                       value,
                       kv_cache,
                       attn_metadata,
-                      self._k_scale,
-                      self._v_scale,
                       output=output)
 
 
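This call-site change is what drives the rest of the diff: `unified_attention` and `unified_attention_with_output` now hand the attention layer module itself to the impl, so a backend (including an out-of-tree one) can read whatever per-layer attributes it needs without those values being threaded through every `forward()` signature. A rough sketch of that shape, with hypothetical names standing in for vLLM's actual layer class:

    # Hedged sketch; SketchAttention is a stand-in, not vLLM's Attention class.
    import torch


    class SketchAttention(torch.nn.Module):

        def __init__(self, impl):
            super().__init__()
            self.impl = impl
            # Per-layer KV-cache quantization scales the impl may consult.
            self._k_scale = 1.0
            self._v_scale = 1.0

        def forward(self, query, key, value, kv_cache, attn_metadata):
            # After this commit the layer passes itself to the impl; before,
            # it had to pass self._k_scale and self._v_scale explicitly.
            return self.impl.forward(self, query, key, value, kv_cache,
                                     attn_metadata)

Paired with the `ToyAttentionImpl` sketch earlier, `SketchAttention(ToyAttentionImpl())` shows the whole round trip: the layer owns the scales, and the impl pulls them off the object it is given.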
@@ -130,13 +130,12 @@ class FlashAttentionImpl(AttentionImpl):
 
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -151,7 +150,7 @@ class FlashAttentionImpl(AttentionImpl):
             shape = [num_tokens, num_heads * head_size]
         """
         # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
-        assert k_scale == 1.0 and v_scale == 1.0, (
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
 
         assert output is not None, "Output tensor must be provided."
@@ -183,8 +182,8 @@ class FlashAttentionImpl(AttentionImpl):
             value_cache,
             attn_metadata.slot_mapping,
             self.kv_cache_dtype,
-            k_scale,
-            v_scale,
+            layer._k_scale,
+            layer._v_scale,
         )
 
         # Compute attention and update output up to `num_actual_tokens`.