[misc] use out argument for flash attention (#10822)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao, 2024-12-02 02:50:10 -08:00 (committed by GitHub)
parent e95f275f57
commit a4c4daf364
13 changed files with 141 additions and 154 deletions
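
The change threads an optional `output` tensor through every attention backend's `forward()` so the flash-attention kernels write their result in place via their `out=` argument instead of returning a freshly allocated tensor; the `Attention` layer now allocates `output` itself, which keeps the allocation inside the piecewise CUDA graph. Below is a minimal sketch of that out-argument pattern, with invented names and a toy attention standing in for the real kernels (it is an illustration, not the vLLM code):

import torch


def attention_with_out(query: torch.Tensor,
                       key: torch.Tensor,
                       value: torch.Tensor,
                       scale: float,
                       output: torch.Tensor) -> torch.Tensor:
    # Stand-in for flash_attn_varlen_func(..., out=output): compute attention
    # over tokens per head and copy the result into the caller's buffer.
    q = query.transpose(0, 1)  # [num_heads, num_tokens, head_size]
    k = key.transpose(0, 1)
    v = value.transpose(0, 1)
    attn = torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1)
    output.copy_((attn @ v).transpose(0, 1))
    return output


# Caller side: allocate the output up front (inside the captured region when
# piecewise cudagraph is enabled) and hand it to the kernel.
num_tokens, num_heads, head_size = 4, 8, 64
q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn(num_tokens, num_heads, head_size)
v = torch.randn(num_tokens, num_heads, head_size)
out = torch.empty_like(q)
attention_with_out(q, k, v, head_size**-0.5, out)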


@@ -247,5 +247,6 @@ class AttentionImpl(ABC, Generic[T]):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError


@@ -360,6 +360,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
@@ -448,5 +449,6 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
                 blocksparse_head_sliding_step=self.head_sliding_step,
             )
+        assert output is not None
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)


@@ -638,24 +638,27 @@ class FlashAttentionImpl(AttentionImpl):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.

         Args:
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
+            query: shape = [num_tokens, num_heads, head_size]
+            key: shape = [num_tokens, num_kv_heads, head_size]
+            value: shape = [num_tokens, num_kv_heads, head_size]
+            output: shape = [num_tokens, num_heads, head_size]
             kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
                 NOTE: kv_cache will be an empty tensor with shape [0]
                 for profiling run.
             attn_metadata: Metadata for attention.
-        Returns:
-            shape = [num_tokens, num_heads * head_size]
+        NOTE: It in-place updates the output tensor.
         """
         # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
+        assert output is not None, "Output tensor must be provided."

         if (attn_type == AttentionType.ENCODER
                 and (not attn_metadata.is_all_encoder_attn_metadata_set)):
             raise AttributeError("Encoder attention requires setting "
@@ -666,23 +669,12 @@ class FlashAttentionImpl(AttentionImpl):
                                  "requires setting cross-attention "
                                  "metadata attributes.")

-        num_heads: int = self.num_heads
-        head_size: int = self.head_size
-        num_kv_heads: int = self.num_kv_heads
         kv_cache_dtype: str = self.kv_cache_dtype
         softmax_scale: float = self.scale
         window_size = self.sliding_window
         alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
         logits_soft_cap: Optional[float] = self.logits_soft_cap

-        num_tokens, hidden_size = query.shape
-
-        # Reshape the query, key, and value tensors.
-        query = query.view(-1, num_heads, head_size)
-        if (key is not None) and (value is not None):
-            key = key.view(-1, num_kv_heads, head_size)
-            value = value.view(-1, num_kv_heads, head_size)
-
         if kv_cache.numel() > 0:
             key_cache = kv_cache[0]
             value_cache = kv_cache[1]
@@ -721,13 +713,13 @@ class FlashAttentionImpl(AttentionImpl):
             num_decode_query_tokens) = \
                 get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)

         decode_query = query[num_prefill_query_tokens:]
+        decode_output = output[num_prefill_query_tokens:]
         # QKV for prefill.
         query = query[:num_prefill_query_tokens]
+        prefill_output = output[:num_prefill_query_tokens]
         assert query.shape[0] == num_prefill_query_tokens
         assert decode_query.shape[0] == num_decode_query_tokens

-        prefill_output: Optional[torch.Tensor] = None
-        decode_output: Optional[torch.Tensor] = None
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
             if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
@@ -741,7 +733,7 @@ class FlashAttentionImpl(AttentionImpl):
                 key = key[:num_prefill_kv_tokens]
                 value = value[:num_prefill_kv_tokens]

-                prefill_output = flash_attn_varlen_func(
+                flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,
@@ -754,6 +746,7 @@ class FlashAttentionImpl(AttentionImpl):
                     window_size=window_size,
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
+                    out=prefill_output,
                 )
             else:
                 # prefix-enabled attention
@@ -761,7 +754,7 @@ class FlashAttentionImpl(AttentionImpl):
                     "Only decoder-only models support prefix caching")
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                prefill_output = flash_attn_varlen_func(  # noqa
+                flash_attn_varlen_func(  # noqa
                     q=query,
                     k=key_cache,
                     v=value_cache,
@@ -775,6 +768,7 @@ class FlashAttentionImpl(AttentionImpl):
                     alibi_slopes=alibi_slopes,
                     block_table=prefill_meta.block_tables,
                     softcap=logits_soft_cap,
+                    out=prefill_output,
                 )

         if decode_meta := attn_metadata.decode_metadata:
@@ -788,7 +782,7 @@ class FlashAttentionImpl(AttentionImpl):
                 assert attn_type == AttentionType.DECODER, (
                     "Only decoder-only models support max_decode_query_len > 1"
                 )
-                decode_output = flash_attn_varlen_func(
+                flash_attn_varlen_func(
                     q=decode_query,
                     k=key_cache,
                     v=value_cache,
@@ -802,6 +796,7 @@ class FlashAttentionImpl(AttentionImpl):
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     block_table=decode_meta.block_tables,
+                    out=decode_output,
                 )
             else:
                 # Use flash_attn_with_kvcache for normal decoding.
@@ -810,7 +805,7 @@ class FlashAttentionImpl(AttentionImpl):
                     _,
                     block_tables_arg,
                 ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
-                decode_output = flash_attn_with_kvcache(
+                flash_attn_with_kvcache(
                     q=decode_query.unsqueeze(1),
                     k_cache=key_cache,
                     v_cache=value_cache,
@@ -821,20 +816,8 @@ class FlashAttentionImpl(AttentionImpl):
                     window_size=window_size,
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
-                ).squeeze(1)
-
-        if prefill_output is None:
-            assert decode_output is not None
-            return decode_output.view(num_decode_query_tokens, hidden_size)
-        if decode_output is None:
-            assert prefill_output is not None
-            return prefill_output.view(num_prefill_query_tokens, hidden_size)
-
-        assert decode_meta is not None
-        decode_output = decode_output.squeeze(1)
-        output = torch.cat([prefill_output, decode_output], dim=0)
-        return output.view(num_tokens, hidden_size)
+                    out=decode_output.unsqueeze(1),
+                )
+        return output
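
After this change, `prefill_output` and `decode_output` are views into the caller-provided `output`, so each kernel call fills its own rows of the same buffer and the old concatenate-and-reshape epilogue disappears. A small sketch of that slicing pattern, with toy shapes and stand-ins for the real kernels (illustrative only):

import torch

num_prefill_tokens, num_decode_tokens, num_heads, head_size = 3, 2, 4, 8
output = torch.zeros(num_prefill_tokens + num_decode_tokens, num_heads, head_size)

prefill_output = output[:num_prefill_tokens]  # view, shares storage with `output`
decode_output = output[num_prefill_tokens:]   # view, shares storage with `output`

# Stand-ins for flash_attn_varlen_func(..., out=prefill_output) and the
# decode-path call writing via out=decode_output.
prefill_output.copy_(torch.ones_like(prefill_output))
decode_output.copy_(torch.full_like(decode_output, 2.0))

assert output[:num_prefill_tokens].eq(1).all()
assert output[num_prefill_tokens:].eq(2).all()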


@@ -774,7 +774,11 @@ class FlashInferImpl(AttentionImpl):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+
+        # TODO: directly write to output tensor
+
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "


@@ -145,6 +145,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.


@@ -173,6 +173,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.


@@ -151,6 +151,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.


@@ -415,6 +415,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.


@@ -431,6 +431,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.


@@ -417,6 +417,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.


@@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional
 import torch
 import torch.nn as nn

-import vllm.envs as envs
 from vllm.attention import AttentionMetadata, AttentionType
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig, get_current_vllm_config
@@ -12,7 +11,7 @@ from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
-from vllm.platforms import current_platform
+from vllm.platforms import _Backend, current_platform
 from vllm.utils import direct_register_custom_op
@@ -97,14 +96,23 @@ class Attention(nn.Module):
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
                              blocksparse_params, logits_soft_cap)
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.num_kv_heads = num_kv_heads
         self.backend = backend_name_to_enum(attn_backend.get_name())

         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
         # torch.compile works by registering the attention as one giant
         # opaque custom op. For other platforms, we directly call them
         # and let torch.compile handle them.
-        self.use_direct_call = envs.VLLM_USE_V1 or not (
-            current_platform.is_cuda_alike() or current_platform.is_cpu())
+        self.use_direct_call = not current_platform.is_cuda_alike(
+        ) and not current_platform.is_cpu()
+
+        # For some attention backends, we allocate an output tensor before
+        # calling the custom op. When piecewise cudagraph is enabled, this
+        # makes sure the output tensor is allocated inside the cudagraph.
+        self.use_output = self.backend == _Backend.FLASH_ATTN or \
+            self.backend == _Backend.FLASH_ATTN_VLLM_V1
+
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
@@ -130,6 +138,22 @@
                                      self._k_scale,
                                      self._v_scale,
                                      attn_type=attn_type)
+        elif self.use_output:
+            output = torch.empty_like(query)
+            hidden_size = query.size(-1)
+            # Reshape the query, key, and value tensors.
+            # NOTE(woosuk): We do this outside the custom op to minimize the
+            # CPU overheads from the non-CUDA-graph regions.
+            query = query.view(-1, self.num_heads, self.head_size)
+            output = output.view(-1, self.num_heads, self.head_size)
+            if key is not None:
+                key = key.view(-1, self.num_kv_heads, self.head_size)
+            if value is not None:
+                value = value.view(-1, self.num_kv_heads, self.head_size)
+            torch.ops.vllm.unified_attention_with_output(
+                query, key, value, output, kv_cache, attn_type,
+                self.layer_name)
+            return output.view(-1, hidden_size)
         else:
             return torch.ops.vllm.unified_attention(query, key, value,
                                                     kv_cache, attn_type,
@@ -183,3 +207,47 @@ direct_register_custom_op(
     fake_impl=unified_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )
+
+
+def unified_attention_with_output(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    kv_cache: torch.Tensor,
+    attn_type: str,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.dynamic_forward_context
+    self = forward_context.static_forward_context[layer_name]
+    self.impl.forward(query,
+                      key,
+                      value,
+                      kv_cache,
+                      attn_metadata,
+                      self._k_scale,
+                      self._v_scale,
+                      attn_type=attn_type,
+                      output=output)
+
+
+def unified_attention_with_output_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    kv_cache: torch.Tensor,
+    attn_type: str,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="unified_attention_with_output",
+    op_func=unified_attention_with_output,
+    mutates_args=["kv_cache", "output"],
+    fake_impl=unified_attention_with_output_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
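
The new op is registered with mutates_args=["kv_cache", "output"] so the compiler knows the kernel writes into caller-owned buffers rather than returning fresh tensors. Below is a hedged illustration of the same idea using plain torch.library.custom_op (PyTorch 2.4+), not vLLM's direct_register_custom_op helper; the op name "demo::attention_out" and the toy math are invented for this sketch:

import torch


@torch.library.custom_op("demo::attention_out", mutates_args=("output",))
def attention_out(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                  output: torch.Tensor) -> None:
    # Toy attention that writes its result into the caller-provided buffer.
    scale = query.shape[-1]**-0.5
    attn = torch.softmax(query @ key.transpose(-1, -2) * scale, dim=-1)
    output.copy_(attn @ value)


@attention_out.register_fake
def _(query, key, value, output) -> None:
    # Nothing to return: the result lives in the mutated `output` buffer.
    return


q = k = v = torch.randn(2, 4, 8)
out = torch.empty_like(q)
torch.ops.demo.attention_out(q, k, v, out)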


@@ -2238,7 +2238,7 @@ class CompilationConfig(BaseModel):
     custom_ops: List[str] = Field(default_factory=list)
     splitting_ops: List[str] = Field(default_factory=lambda: [
         "vllm.unified_attention",
-        "vllm.unified_v1_flash_attention",
+        "vllm.unified_attention_with_output",
     ])
     use_inductor: bool = True


@@ -6,8 +6,6 @@ import torch
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
-from vllm.forward_context import get_forward_context
-from vllm.utils import direct_register_custom_op
 from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -113,13 +111,14 @@ class FlashAttentionImpl(AttentionImpl):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.

         Args:
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
+            query: shape = [num_tokens, num_heads, head_size]
+            key: shape = [num_tokens, num_kv_heads, head_size]
+            value: shape = [num_tokens, num_kv_heads, head_size]
             kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
         Returns:
@@ -135,118 +134,42 @@ class FlashAttentionImpl(AttentionImpl):
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")

-        # Reshape the query, key, and value tensors.
-        # NOTE(woosuk): We do this outside the custom op to minimize the CPU
-        # overheads from the non-CUDA-graph regions.
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
-
-        output = torch.empty_like(query)
-        torch.ops.vllm.unified_v1_flash_attention(
-            output,
-            query,
-            key,
-            value,
-            self.num_heads,
-            self.head_size,
-            self.num_kv_heads,
-            kv_cache,
-            self.kv_cache_dtype,
-            k_scale,
-            v_scale,
-            self.scale,
-            self.sliding_window,
-            self.alibi_slopes,
-            self.logits_soft_cap,
-        )
-        return output.view(-1, self.num_heads * self.head_size)
+        if attn_metadata is None:
+            # Profiling run.
+            return output
+
+        num_actual_tokens = attn_metadata.num_actual_tokens
+
+        # Reshape the input keys and values and store them in the cache.
+        key_cache = kv_cache[0]
+        value_cache = kv_cache[1]
+        torch.ops._C_cache_ops.reshape_and_cache_flash(
+            key[:num_actual_tokens],
+            value[:num_actual_tokens],
+            key_cache,
+            value_cache,
+            attn_metadata.slot_mapping,
+            self.kv_cache_dtype,
+            k_scale,
+            v_scale,
+        )
+
+        # Compute attention and update output up to `num_actual_tokens`.
+        flash_attn_varlen_func(
+            q=query[:num_actual_tokens],
+            k=key_cache,
+            v=value_cache,
+            out=output[:num_actual_tokens],
+            cu_seqlens_q=attn_metadata.query_start_loc,
+            max_seqlen_q=attn_metadata.max_query_len,
+            cu_seqlens_k=attn_metadata.seq_start_loc,
+            max_seqlen_k=attn_metadata.max_seq_len,
+            softmax_scale=self.scale,
+            causal=True,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            block_table=attn_metadata.block_table,
+            softcap=self.logits_soft_cap,
+        )
+        return output
-
-
-def unified_v1_flash_attention(
-    output: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    num_heads: int,
-    head_size: int,
-    num_kv_heads: int,
-    kv_cache: torch.Tensor,
-    kv_cache_dtype: str,
-    k_scale: float,
-    v_scale: float,
-    softmax_scale: float,
-    window_size: Optional[List[int]] = None,
-    alibi_slopes: Optional[torch.Tensor] = None,
-    logits_soft_cap: Optional[float] = None,
-) -> None:
-    context = get_forward_context()
-    current_metadata = context.dynamic_forward_context
-    if current_metadata is None:
-        # Profiling run.
-        return
-
-    assert current_metadata is not None
-    assert isinstance(current_metadata, FlashAttentionMetadata)
-    attn_metadata: FlashAttentionMetadata = current_metadata
-    num_actual_tokens = attn_metadata.num_actual_tokens
-
-    # Reshape the input keys and values and store them in the cache.
-    key_cache = kv_cache[0]
-    value_cache = kv_cache[1]
-    torch.ops._C_cache_ops.reshape_and_cache_flash(
-        key[:num_actual_tokens],
-        value[:num_actual_tokens],
-        key_cache,
-        value_cache,
-        attn_metadata.slot_mapping,
-        kv_cache_dtype,
-        k_scale,
-        v_scale,
-    )
-
-    # Compute attention and update output up to `num_actual_tokens`.
-    flash_attn_varlen_func(
-        q=query[:num_actual_tokens],
-        k=key_cache,
-        v=value_cache,
-        out=output[:num_actual_tokens],
-        cu_seqlens_q=attn_metadata.query_start_loc,
-        max_seqlen_q=attn_metadata.max_query_len,
-        cu_seqlens_k=attn_metadata.seq_start_loc,
-        max_seqlen_k=attn_metadata.max_seq_len,
-        softmax_scale=softmax_scale,
-        causal=True,
-        alibi_slopes=alibi_slopes,
-        window_size=window_size,
-        block_table=attn_metadata.block_table,
-        softcap=logits_soft_cap,
-    )
-
-
-def unified_v1_flash_attention_fake(
-    output: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    num_heads: int,
-    head_size: int,
-    num_kv_heads: int,
-    kv_cache: torch.Tensor,
-    kv_cache_dtype: str,
-    k_scale: float,
-    v_scale: float,
-    softmax_scale: float,
-    window_size: Optional[List[int]] = None,
-    alibi_slopes: Optional[torch.Tensor] = None,
-    logits_soft_cap: Optional[float] = None,
-) -> None:
-    return
-
-
-direct_register_custom_op(
-    op_name="unified_v1_flash_attention",
-    op_func=unified_v1_flash_attention,
-    mutates_args=["kv_cache", "output"],
-    fake_impl=unified_v1_flash_attention_fake,
-)
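
In the v1 backend the kernel now writes through `out=output[:num_actual_tokens]`, a view into the preallocated buffer, so only the rows for real tokens are filled and any remaining rows (for example, padding in a padded batch) keep their previous contents. A hedged sketch of that slicing behaviour, with made-up sizes and a toy stand-in for the kernel:

import torch

num_padded_tokens, num_actual_tokens, num_heads, head_size = 8, 5, 2, 4
output = torch.zeros(num_padded_tokens, num_heads, head_size)


def toy_kernel(out: torch.Tensor) -> None:
    # Stand-in for flash_attn_varlen_func(..., out=...): in-place write.
    out.fill_(1.0)


toy_kernel(output[:num_actual_tokens])
assert output[:num_actual_tokens].eq(1).all()
assert output[num_actual_tokens:].eq(0).all()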