Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 08:45:36 +08:00)
[misc] use out argument for flash attention (#10822)

Signed-off-by: youkaichao <youkaichao@gmail.com>

Parent: e95f275f57
Commit: a4c4daf364
@@ -247,5 +247,6 @@ class AttentionImpl(ABC, Generic[T]):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
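The new optional `output` parameter lets the caller hand the backend a pre-allocated buffer that the attention kernel writes into, instead of each backend allocating and returning its own tensor. A minimal sketch of that calling convention, using plain PyTorch and hypothetical names (not vLLM code):

import torch
from typing import Optional

def toy_attention(query: torch.Tensor,
                  key: torch.Tensor,
                  value: torch.Tensor,
                  output: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Allocate only if the caller did not supply a buffer.
    if output is None:
        output = torch.empty_like(query)
    # Compute attention, then write the result into the caller's buffer.
    result = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    output.copy_(result)
    return output

# [batch, num_heads, num_tokens, head_size]
q = k = v = torch.randn(2, 4, 8, 16)
buf = torch.empty_like(q)
out = toy_attention(q, k, v, output=buf)
assert out.data_ptr() == buf.data_ptr()  # the caller-provided buffer was reused

Writing into a caller-owned buffer matters when the caller wants to control where the result lives, for example inside a CUDA-graph-captured region (see the Attention layer changes further down).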
@@ -360,6 +360,7 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.

@@ -448,5 +449,6 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
                 blocksparse_head_sliding_step=self.head_sliding_step,
             )

+        assert output is not None
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
@@ -638,24 +638,27 @@ class FlashAttentionImpl(AttentionImpl):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.

         Args:
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
+            query: shape = [num_tokens, num_heads, head_size]
+            key: shape = [num_tokens, num_kv_heads, head_size]
+            value: shape = [num_tokens, num_kv_heads, head_size]
+            output: shape = [num_tokens, num_heads, head_size]
             kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
                 NOTE: kv_cache will be an empty tensor with shape [0]
                 for profiling run.
             attn_metadata: Metadata for attention.
-        Returns:
-            shape = [num_tokens, num_heads * head_size]
+        NOTE: It in-place updates the output tensor.
         """
         # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")

+        assert output is not None, "Output tensor must be provided."
+
         if (attn_type == AttentionType.ENCODER
                 and (not attn_metadata.is_all_encoder_attn_metadata_set)):
             raise AttributeError("Encoder attention requires setting "

@@ -666,23 +669,12 @@ class FlashAttentionImpl(AttentionImpl):
                                  "requires setting cross-attention "
                                  "metadata attributes.")

-        num_heads: int = self.num_heads
-        head_size: int = self.head_size
-        num_kv_heads: int = self.num_kv_heads
         kv_cache_dtype: str = self.kv_cache_dtype
         softmax_scale: float = self.scale
         window_size = self.sliding_window
         alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes
         logits_soft_cap: Optional[float] = self.logits_soft_cap

-        num_tokens, hidden_size = query.shape
-
-        # Reshape the query, key, and value tensors.
-        query = query.view(-1, num_heads, head_size)
-        if (key is not None) and (value is not None):
-            key = key.view(-1, num_kv_heads, head_size)
-            value = value.view(-1, num_kv_heads, head_size)
-
         if kv_cache.numel() > 0:
             key_cache = kv_cache[0]
             value_cache = kv_cache[1]

@@ -721,13 +713,13 @@ class FlashAttentionImpl(AttentionImpl):
         num_decode_query_tokens) = \
            get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type)
        decode_query = query[num_prefill_query_tokens:]
+       decode_output = output[num_prefill_query_tokens:]
        # QKV for prefill.
        query = query[:num_prefill_query_tokens]
+       prefill_output = output[:num_prefill_query_tokens]
        assert query.shape[0] == num_prefill_query_tokens
        assert decode_query.shape[0] == num_decode_query_tokens

-       prefill_output: Optional[torch.Tensor] = None
-       decode_output: Optional[torch.Tensor] = None
        if prefill_meta := attn_metadata.prefill_metadata:
            # Prompt run.
            if (kv_cache.numel() == 0 or prefill_meta.block_tables is None

@@ -741,7 +733,7 @@ class FlashAttentionImpl(AttentionImpl):
                 key = key[:num_prefill_kv_tokens]
                 value = value[:num_prefill_kv_tokens]

-                prefill_output = flash_attn_varlen_func(
+                flash_attn_varlen_func(
                     q=query,
                     k=key,
                     v=value,

@@ -754,6 +746,7 @@ class FlashAttentionImpl(AttentionImpl):
                     window_size=window_size,
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
+                    out=prefill_output,
                 )
             else:
                 # prefix-enabled attention

@@ -761,7 +754,7 @@ class FlashAttentionImpl(AttentionImpl):
                     "Only decoder-only models support prefix caching")
                 assert prefill_meta.seq_lens is not None
                 max_seq_len = max(prefill_meta.seq_lens)
-                prefill_output = flash_attn_varlen_func(  # noqa
+                flash_attn_varlen_func(  # noqa
                     q=query,
                     k=key_cache,
                     v=value_cache,

@@ -775,6 +768,7 @@ class FlashAttentionImpl(AttentionImpl):
                     alibi_slopes=alibi_slopes,
                     block_table=prefill_meta.block_tables,
                     softcap=logits_soft_cap,
+                    out=prefill_output,
                 )

         if decode_meta := attn_metadata.decode_metadata:

@@ -788,7 +782,7 @@ class FlashAttentionImpl(AttentionImpl):
                 assert attn_type == AttentionType.DECODER, (
                     "Only decoder-only models support max_decode_query_len > 1"
                 )
-                decode_output = flash_attn_varlen_func(
+                flash_attn_varlen_func(
                     q=decode_query,
                     k=key_cache,
                     v=value_cache,

@@ -802,6 +796,7 @@ class FlashAttentionImpl(AttentionImpl):
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     block_table=decode_meta.block_tables,
+                    out=decode_output,
                 )
             else:
                 # Use flash_attn_with_kvcache for normal decoding.

@@ -810,7 +805,7 @@ class FlashAttentionImpl(AttentionImpl):
                     _,
                     block_tables_arg,
                 ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
-                decode_output = flash_attn_with_kvcache(
+                flash_attn_with_kvcache(
                     q=decode_query.unsqueeze(1),
                     k_cache=key_cache,
                     v_cache=value_cache,
@@ -821,20 +816,8 @@ class FlashAttentionImpl(AttentionImpl):
                     window_size=window_size,
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
-                ).squeeze(1)
-
-        if prefill_output is None:
-            assert decode_output is not None
-            return decode_output.view(num_decode_query_tokens, hidden_size)
-        if decode_output is None:
-            assert prefill_output is not None
-            return prefill_output.view(num_prefill_query_tokens, hidden_size)
-
-        assert decode_meta is not None
-        decode_output = decode_output.squeeze(1)
-        output = torch.cat([prefill_output, decode_output], dim=0)
-        return output.view(num_tokens, hidden_size)
+                    out=decode_output.unsqueeze(1),
+                )
         return output
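Because prefill and decode now write into `output[:num_prefill_query_tokens]` and `output[num_prefill_query_tokens:]` respectively, the old branching and `torch.cat` merge at the end of the method become unnecessary. A small stand-alone illustration of the idea, with hypothetical shapes (not vLLM code): slicing a tensor yields views that share storage, so two kernels can fill disjoint regions of one buffer in place.

import torch

num_prefill, num_decode, num_heads, head_size = 6, 2, 4, 16
output = torch.empty(num_prefill + num_decode, num_heads, head_size)

prefill_out = output[:num_prefill]   # view into `output`
decode_out = output[num_prefill:]    # view into `output`

# Stand-ins for the two attention calls; each writes its slice in place.
prefill_out.copy_(torch.randn(num_prefill, num_heads, head_size))
decode_out.copy_(torch.randn(num_decode, num_heads, head_size))

# `output` already holds the concatenation of both results; no torch.cat needed.
assert prefill_out.data_ptr() == output.data_ptr()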
@@ -774,7 +774,11 @@ class FlashInferImpl(AttentionImpl):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+
+        # TODO: directly write to output tensor
+
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
                                       "encoder/decoder cross-attention "

@@ -145,6 +145,7 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.

@@ -173,6 +173,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.

@@ -151,6 +151,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.

@@ -415,6 +415,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.

@@ -431,6 +431,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.

@@ -417,6 +417,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: str = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.

@@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional
 import torch
 import torch.nn as nn

-import vllm.envs as envs
 from vllm.attention import AttentionMetadata, AttentionType
 from vllm.attention.selector import backend_name_to_enum, get_attn_backend
 from vllm.config import CacheConfig, get_current_vllm_config

@@ -12,7 +11,7 @@ from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
-from vllm.platforms import current_platform
+from vllm.platforms import _Backend, current_platform
 from vllm.utils import direct_register_custom_op

@@ -97,14 +96,23 @@ class Attention(nn.Module):
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
                              blocksparse_params, logits_soft_cap)
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.num_kv_heads = num_kv_heads
         self.backend = backend_name_to_enum(attn_backend.get_name())

         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
         # torch.compile works by registering the attention as one giant
         # opaque custom op. For other platforms, we directly call them
         # and let torch.compile handle them.
-        self.use_direct_call = envs.VLLM_USE_V1 or not (
-            current_platform.is_cuda_alike() or current_platform.is_cpu())
+        self.use_direct_call = not current_platform.is_cuda_alike(
+        ) and not current_platform.is_cpu()
+
+        # For some attention backends, we allocate an output tensor before
+        # calling the custom op. When piecewise cudagraph is enabled, this
+        # makes sure the output tensor is allocated inside the cudagraph.
+        self.use_output = self.backend == _Backend.FLASH_ATTN or \
+            self.backend == _Backend.FLASH_ATTN_VLLM_V1
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")

@@ -130,6 +138,22 @@ class Attention(nn.Module):
                                      self._k_scale,
                                      self._v_scale,
                                      attn_type=attn_type)
+        elif self.use_output:
+            output = torch.empty_like(query)
+            hidden_size = query.size(-1)
+            # Reshape the query, key, and value tensors.
+            # NOTE(woosuk): We do this outside the custom op to minimize the
+            # CPU overheads from the non-CUDA-graph regions.
+            query = query.view(-1, self.num_heads, self.head_size)
+            output = output.view(-1, self.num_heads, self.head_size)
+            if key is not None:
+                key = key.view(-1, self.num_kv_heads, self.head_size)
+            if value is not None:
+                value = value.view(-1, self.num_kv_heads, self.head_size)
+            torch.ops.vllm.unified_attention_with_output(
+                query, key, value, output, kv_cache, attn_type,
+                self.layer_name)
+            return output.view(-1, hidden_size)
         else:
             return torch.ops.vllm.unified_attention(query, key, value,
                                                     kv_cache, attn_type,
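The key point in the `elif self.use_output:` branch is that the result buffer is created with `torch.empty_like(query)` before the opaque custom op is invoked, so when piecewise CUDA graphs capture the surrounding region the buffer is allocated inside the captured graph and the op only mutates it. A toy sketch of that structure, with a hypothetical op and names (not vLLM code):

import torch

def attention_op(query: torch.Tensor, output: torch.Tensor) -> None:
    # Stand-in for an opaque custom op: it only mutates `output`, returns nothing.
    output.copy_(query * 2.0)

def layer_forward(query: torch.Tensor) -> torch.Tensor:
    # Allocate the result buffer outside the op, then let the op fill it.
    output = torch.empty_like(query)
    attention_op(query, output)
    return output

x = torch.randn(8, 64)
y = layer_forward(x)
assert torch.allclose(y, x * 2.0)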
@@ -183,3 +207,47 @@ direct_register_custom_op(
     fake_impl=unified_attention_fake,
     dispatch_key=current_platform.dispatch_key,
 )
+
+
+def unified_attention_with_output(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    kv_cache: torch.Tensor,
+    attn_type: str,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.dynamic_forward_context
+    self = forward_context.static_forward_context[layer_name]
+    self.impl.forward(query,
+                      key,
+                      value,
+                      kv_cache,
+                      attn_metadata,
+                      self._k_scale,
+                      self._v_scale,
+                      attn_type=attn_type,
+                      output=output)
+
+
+def unified_attention_with_output_fake(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    kv_cache: torch.Tensor,
+    attn_type: str,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="unified_attention_with_output",
+    op_func=unified_attention_with_output,
+    mutates_args=["kv_cache", "output"],
+    fake_impl=unified_attention_with_output_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
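The registration above pairs an eager implementation that mutates `kv_cache` and `output` with a no-op fake implementation so `torch.compile` can trace the op without running it. A rough stand-alone equivalent using the stock `torch.library` decorator API (PyTorch 2.4+), with a hypothetical op name, and not how vLLM actually registers its ops:

import torch

@torch.library.custom_op("toylib::attention_with_output", mutates_args=("output",))
def attention_with_output(query: torch.Tensor, output: torch.Tensor) -> None:
    # Eager implementation: write the result into the pre-allocated buffer.
    output.copy_(query * 2.0)

@attention_with_output.register_fake
def _(query: torch.Tensor, output: torch.Tensor) -> None:
    # Fake (meta) implementation: no computation, only shape/dtype semantics.
    return None

q = torch.randn(4, 8)
out = torch.empty_like(q)
torch.ops.toylib.attention_with_output(q, out)
assert torch.allclose(out, q * 2.0)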
@@ -2238,7 +2238,7 @@ class CompilationConfig(BaseModel):
     custom_ops: List[str] = Field(default_factory=list)
     splitting_ops: List[str] = Field(default_factory=lambda: [
         "vllm.unified_attention",
-        "vllm.unified_v1_flash_attention",
+        "vllm.unified_attention_with_output",
     ])

     use_inductor: bool = True
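Since the old `vllm.unified_v1_flash_attention` op is removed, the default `splitting_ops` list (which tells piecewise `torch.compile` where to cut the graph) now names `vllm.unified_attention_with_output` instead. A quick sanity check of the default, assuming vLLM is installed and `CompilationConfig` is importable from `vllm.config`:

from vllm.config import CompilationConfig

cfg = CompilationConfig()
# Attention custom ops are kept as graph-splitting points for piecewise compilation.
assert "vllm.unified_attention" in cfg.splitting_ops
assert "vllm.unified_attention_with_output" in cfg.splitting_ops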
@@ -6,8 +6,6 @@ import torch

 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
-from vllm.forward_context import get_forward_context
-from vllm.utils import direct_register_custom_op
 from vllm.vllm_flash_attn import flash_attn_varlen_func


@@ -113,13 +111,14 @@ class FlashAttentionImpl(AttentionImpl):
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
+        output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.

         Args:
-            query: shape = [num_tokens, num_heads * head_size]
-            key: shape = [num_tokens, num_kv_heads * head_size]
-            value: shape = [num_tokens, num_kv_heads * head_size]
+            query: shape = [num_tokens, num_heads, head_size]
+            key: shape = [num_tokens, num_kv_heads, head_size]
+            value: shape = [num_tokens, num_kv_heads, head_size]
             kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
         Returns:
@@ -135,118 +134,42 @@ class FlashAttentionImpl(AttentionImpl):
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")

-        # Reshape the query, key, and value tensors.
-        # NOTE(woosuk): We do this outside the custom op to minimize the CPU
-        # overheads from the non-CUDA-graph regions.
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
-
-        output = torch.empty_like(query)
-        torch.ops.vllm.unified_v1_flash_attention(
-            output,
-            query,
-            key,
-            value,
-            self.num_heads,
-            self.head_size,
-            self.num_kv_heads,
-            kv_cache,
-            self.kv_cache_dtype,
-            k_scale,
-            v_scale,
-            self.scale,
-            self.sliding_window,
-            self.alibi_slopes,
-            self.logits_soft_cap,
-        )
-        return output.view(-1, self.num_heads * self.head_size)
-
-
-def unified_v1_flash_attention(
-    output: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    num_heads: int,
-    head_size: int,
-    num_kv_heads: int,
-    kv_cache: torch.Tensor,
-    kv_cache_dtype: str,
-    k_scale: float,
-    v_scale: float,
-    softmax_scale: float,
-    window_size: Optional[List[int]] = None,
-    alibi_slopes: Optional[torch.Tensor] = None,
-    logits_soft_cap: Optional[float] = None,
-) -> None:
-    context = get_forward_context()
-    current_metadata = context.dynamic_forward_context
-    if current_metadata is None:
-        # Profiling run.
-        return
-
-    assert current_metadata is not None
-    assert isinstance(current_metadata, FlashAttentionMetadata)
-    attn_metadata: FlashAttentionMetadata = current_metadata
-    num_actual_tokens = attn_metadata.num_actual_tokens
-
-    # Reshape the input keys and values and store them in the cache.
-    key_cache = kv_cache[0]
-    value_cache = kv_cache[1]
-    torch.ops._C_cache_ops.reshape_and_cache_flash(
-        key[:num_actual_tokens],
-        value[:num_actual_tokens],
-        key_cache,
-        value_cache,
-        attn_metadata.slot_mapping,
-        kv_cache_dtype,
-        k_scale,
-        v_scale,
-    )
-
-    # Compute attention and update output up to `num_actual_tokens`.
-    flash_attn_varlen_func(
-        q=query[:num_actual_tokens],
-        k=key_cache,
-        v=value_cache,
-        out=output[:num_actual_tokens],
-        cu_seqlens_q=attn_metadata.query_start_loc,
-        max_seqlen_q=attn_metadata.max_query_len,
-        cu_seqlens_k=attn_metadata.seq_start_loc,
-        max_seqlen_k=attn_metadata.max_seq_len,
-        softmax_scale=softmax_scale,
-        causal=True,
-        alibi_slopes=alibi_slopes,
-        window_size=window_size,
-        block_table=attn_metadata.block_table,
-        softcap=logits_soft_cap,
-    )
-
-
-def unified_v1_flash_attention_fake(
-    output: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    num_heads: int,
-    head_size: int,
-    num_kv_heads: int,
-    kv_cache: torch.Tensor,
-    kv_cache_dtype: str,
-    k_scale: float,
-    v_scale: float,
-    softmax_scale: float,
-    window_size: Optional[List[int]] = None,
-    alibi_slopes: Optional[torch.Tensor] = None,
-    logits_soft_cap: Optional[float] = None,
-) -> None:
-    return
-
-
-direct_register_custom_op(
-    op_name="unified_v1_flash_attention",
-    op_func=unified_v1_flash_attention,
-    mutates_args=["kv_cache", "output"],
-    fake_impl=unified_v1_flash_attention_fake,
-)
+        if attn_metadata is None:
+            # Profiling run.
+            return output
+
+        num_actual_tokens = attn_metadata.num_actual_tokens
+
+        # Reshape the input keys and values and store them in the cache.
+        key_cache = kv_cache[0]
+        value_cache = kv_cache[1]
+        torch.ops._C_cache_ops.reshape_and_cache_flash(
+            key[:num_actual_tokens],
+            value[:num_actual_tokens],
+            key_cache,
+            value_cache,
+            attn_metadata.slot_mapping,
+            self.kv_cache_dtype,
+            k_scale,
+            v_scale,
+        )
+
+        # Compute attention and update output up to `num_actual_tokens`.
+        flash_attn_varlen_func(
+            q=query[:num_actual_tokens],
+            k=key_cache,
+            v=value_cache,
+            out=output[:num_actual_tokens],
+            cu_seqlens_q=attn_metadata.query_start_loc,
+            max_seqlen_q=attn_metadata.max_query_len,
+            cu_seqlens_k=attn_metadata.seq_start_loc,
+            max_seqlen_k=attn_metadata.max_seq_len,
+            softmax_scale=self.scale,
+            causal=True,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            block_table=attn_metadata.block_table,
+            softcap=self.logits_soft_cap,
+        )
+
+        return output
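In the V1 path the whole computation now happens inside `forward` and writes directly into `out=output[:num_actual_tokens]`; during a profiling run (`attn_metadata is None`) the method simply returns the untouched buffer. The essential buffer discipline, as a toy sketch with hypothetical sizes (not vLLM code): only the first `num_actual_tokens` rows of a persistent, possibly padded output buffer are overwritten in place.

import torch

max_num_tokens, num_heads, head_size = 8, 4, 16
# Persistent output buffer, e.g. one captured inside a CUDA graph.
output = torch.zeros(max_num_tokens, num_heads, head_size)

num_actual_tokens = 5
fresh = torch.randn(num_actual_tokens, num_heads, head_size)

# Equivalent in spirit to passing out=output[:num_actual_tokens] to the kernel.
output[:num_actual_tokens].copy_(fresh)

assert torch.equal(output[:num_actual_tokens], fresh)
# Rows beyond num_actual_tokens keep their previous (padding) contents.
assert torch.equal(output[num_actual_tokens:],
                   torch.zeros(max_num_tokens - num_actual_tokens, num_heads, head_size))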