[ROCm][AITER] Enable fp8 kv cache on rocm aiter backend. (#20295)

Signed-off-by: fsx950223 <fsx950223@outlook.com>
Signed-off-by: amd-ruitang3 <Rui.Tang2@amd.com>
Co-authored-by: amd-ruitang3 <Rui.Tang2@amd.com>
Committed via GitHub on 2025-07-25 21:50:21 +08:00
parent eab2f3980c
commit b3caeb82e7
2 changed files with 320 additions and 96 deletions
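For context, this change lets the AITER-based attention backend consume an fp8 KV cache directly. A minimal usage sketch, assuming the standard vLLM knobs (the model name below is only an example, and VLLM_ROCM_USE_AITER is the usual switch for the AITER kernels on ROCm):

import os

# Assumed environment switch for the ROCm AITER kernels; the fp8 KV cache is
# requested through kv_cache_dtype, which this PR enables for that backend.
os.environ["VLLM_ROCM_USE_AITER"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_cache_dtype="fp8")
out = llm.generate(["Hello from ROCm"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)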


@@ -0,0 +1,191 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import pytest
import torch

import vllm.v1.attention.backends.rocm_aiter_fa  # noqa: F401
from vllm.platforms import current_platform

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
QDTYPES = [None]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]


def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: list[int],
    kv_lens: list[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    num_seqs = len(query_lens)
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: list[torch.Tensor] = []
    start_idx = 0
    for i in range(num_seqs):
        query_len = query_lens[i]
        kv_len = kv_lens[i]
        q = query[start_idx:start_idx + query_len]
        q *= scale

        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
        attn = torch.einsum("qhd,khd->hqk", q, k).float()
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            sliding_window_mask = torch.triu(empty_mask,
                                             diagonal=kv_len -
                                             (query_len + sliding_window) +
                                             1).bool().logical_not()
            mask |= sliding_window_mask
        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)


@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="Only ROCm is supported")
@pytest.mark.parametrize("seq_lens",
                         [[(10, 1328), (5, 18),
                           (129, 463)], [(8, 523), (24, 37), (3, 2011)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_varlen_with_paged_kv(
    seq_lens: list[tuple[int, int]],
    num_heads: tuple[int, int],
    head_size: int,
    sliding_window: Optional[int],
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    num_blocks: int,
    q_dtype: Optional[torch.dtype],
) -> None:
    torch.set_default_device("cuda")
    current_platform.seed_everything(0)
    num_seqs = len(seq_lens)
    query_lens = [x[0] for x in seq_lens]
    kv_lens = [x[1] for x in seq_lens]
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]
    assert num_query_heads % num_kv_heads == 0
    max_query_len = max(query_lens)
    max_kv_len = max(kv_lens)
    window_size = ((sliding_window - 1, 0) if sliding_window is not None else
                   (-1, -1))
    scale = head_size**-0.5

    query = torch.randn(sum(query_lens),
                        num_query_heads,
                        head_size,
                        dtype=dtype)
    key_cache = torch.randn(num_blocks,
                            block_size,
                            num_kv_heads,
                            head_size,
                            dtype=dtype)
    value_cache = torch.randn_like(key_cache)
    cu_query_lens = torch.tensor([0] + query_lens,
                                 dtype=torch.int32).cumsum(dim=0,
                                                           dtype=torch.int32)
    cu_seq_lens = torch.tensor([0] + kv_lens,
                               dtype=torch.int32).cumsum(dim=0,
                                                         dtype=torch.int32)
    kv_lens = torch.tensor(kv_lens, dtype=torch.int32)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 num_blocks,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)

    output = torch.empty_like(query)

    maybe_quantized_query = query
    maybe_quantized_key_cache = key_cache
    maybe_quantized_value_cache = value_cache
    k_descale = None
    v_descale = None
    if q_dtype is not None:
        # QKV are drawn from N(0, 1): no need for a fp8 scaling factor
        maybe_quantized_query = query.to(q_dtype)
        maybe_quantized_key_cache = key_cache.to(q_dtype)
        maybe_quantized_value_cache = value_cache.to(q_dtype)

        scale_shape = (num_seqs, num_kv_heads)
        k_descale = torch.ones(scale_shape, dtype=torch.float32)
        v_descale = torch.ones(scale_shape, dtype=torch.float32)

    torch.ops.vllm.flash_attn_varlen_func(
        maybe_quantized_query,
        maybe_quantized_key_cache,
        maybe_quantized_value_cache,
        out=output,
        cu_seqlens_q=cu_query_lens,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_kv_len,
        softmax_scale=scale,
        alibi_slopes=None,
        window_size=window_size,
        block_table=block_tables,
        cu_seqlens_k=cu_seq_lens,
        k_scale=k_descale,
        v_scale=v_descale,
    )

    ref_output = ref_paged_attn(
        query=query,
        key_cache=key_cache,
        value_cache=value_cache,
        query_lens=query_lens,
        kv_lens=kv_lens,
        block_tables=block_tables,
        scale=scale,
        sliding_window=sliding_window,
        soft_cap=soft_cap,
    )
    atol, rtol = 2e-2, 2e-2
    if q_dtype is not None:
        atol, rtol = 1.5e-1, 1.5e-1
    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
        f"{torch.max(torch.abs(output - ref_output))}"

vllm/v1/attention/backends/rocm_aiter_fa.py

@@ -2,20 +2,21 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with AiterFlashAttention."""
 from dataclasses import dataclass
-from typing import Optional
+from typing import ClassVar, Optional
 
 import torch
 
-from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType,
-                                              is_quantized_kv_cache)
+                                              AttentionMetadata, AttentionType)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
 
+_PARTITION_SIZE_ROCM = 256
+
 if current_platform.is_rocm():
     import aiter
@@ -32,38 +33,54 @@ if current_platform.is_rocm():
         b_seq_lens_loc,
         block_table,
         block_table_stride_0,
+        k_scale,
+        v_scale,
+        output_dtype: tl.constexpr,
         E_DIM: tl.constexpr,
         BLOCK_SIZE: tl.constexpr,
     ):
         batch_idx = tl.program_id(0)
         block_idx = tl.program_id(1)
-        batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx +
-                                      tl.arange(0, 2))
-        batch_token_start, batch_token_end = tl.split(batch_token_indexes)
-        seq_len = batch_token_end - batch_token_start
         batch_query_indexes = tl.load(b_query_lens_loc + batch_idx +
                                       tl.arange(0, 2))
         batch_query_start, batch_query_end = tl.split(batch_query_indexes)
         query_len = batch_query_end - batch_query_start
         if query_len <= 1:
             return
+        batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx +
+                                      tl.arange(0, 2))
+        batch_token_start, batch_token_end = tl.split(batch_token_indexes)
+        seq_len = batch_token_end - batch_token_start
         if block_idx * BLOCK_SIZE < seq_len:
             block_mask = (block_idx * BLOCK_SIZE +
                           tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len
             kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 +
-                             block_idx)
+                             block_idx).to(tl.int64)
             kv_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange(
                 0, BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :]
             k_vals = tl.load(k_buffer_ptr + kv_buffer_off,
                              mask=block_mask,
                              other=0.0)
+            if k_vals.dtype.is_fp8():
+                k_vals = (k_vals.to(tl.float32) *
+                          tl.load(k_scale)).to(output_dtype)
+            else:
+                k_vals = k_vals.to(output_dtype)
             v_vals = tl.load(v_buffer_ptr + kv_buffer_off,
                              mask=block_mask,
                              other=0.0)
+            if v_vals.dtype.is_fp8():
+                v_vals = (v_vals.to(tl.float32) *
+                          tl.load(v_scale)).to(output_dtype)
+            else:
+                v_vals = v_vals.to(output_dtype)
             kv_values_off = batch_token_start * E_DIM + \
                 block_idx * BLOCK_SIZE * E_DIM + \
                 tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM + \
@@ -72,29 +89,44 @@ if current_platform.is_rocm():
             tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)
 
     def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
-                          k_buffer, v_buffer, max_seq_len, total_tokens):
-        H_KV = v_buffer.shape[2]
-        D = v_buffer.shape[3]
-        BLOCK_SIZE = v_buffer.shape[1]
-        dtype = k_buffer.dtype
-        k_values = torch.empty((total_tokens, H_KV, D),
-                               dtype=dtype,
-                               device="cuda")
-        v_values = torch.empty((total_tokens, H_KV, D),
-                               dtype=dtype,
-                               device="cuda")
+                          k_cache, v_cache, max_seq_len, k_scale, v_scale,
+                          output_dtype, total_tokens):
+        H_KV = v_cache.shape[2]
+        D = v_cache.shape[3]
+        BLOCK_SIZE = v_cache.shape[1]
+
+        k_values = torch.empty(
+            (total_tokens, H_KV, D),
+            dtype=output_dtype,
+            device=k_cache.device,
+        )
+        v_values = torch.empty(
+            (total_tokens, H_KV, D),
+            dtype=output_dtype,
+            device=v_cache.device,
+        )
         grid = (block_table.shape[0],
                 (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
-        _vllm_layout_trans_kernel[grid](k_buffer,
-                                        v_buffer,
+        if output_dtype == torch.float16:
+            output_dtype = tl.float16
+        elif output_dtype == torch.bfloat16:
+            output_dtype = tl.bfloat16
+        else:
+            raise ValueError(f"Unsupported output dtype: {output_dtype}")
+        _vllm_layout_trans_kernel[grid](k_cache,
+                                        v_cache,
                                         k_values,
                                         v_values,
                                         b_query_lens_loc,
                                         b_seq_lens_loc,
                                         block_table,
                                         block_table.stride(0),
+                                        k_scale,
+                                        v_scale,
+                                        output_dtype=output_dtype,
                                         E_DIM=H_KV * D,
                                         BLOCK_SIZE=BLOCK_SIZE)
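For reference, the dequantization that the updated kernel applies to each fp8 KV-cache block amounts to the following eager-mode sketch (the helper name and signature here are illustrative, not part of the diff):

import torch

def dequant_kv_block(vals_fp8: torch.Tensor, descale: torch.Tensor,
                     out_dtype: torch.dtype) -> torch.Tensor:
    # Mirror of the Triton path: upcast fp8 values to fp32, apply the
    # per-tensor descale factor (k_scale / v_scale), then cast to the
    # attention compute dtype (fp16 or bf16).
    return (vals_fp8.to(torch.float32) * descale).to(out_dtype)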
@@ -107,16 +139,22 @@ if current_platform.is_rocm():
         out: torch.Tensor,
         cu_seqlens_q: torch.Tensor,
         cu_seqlens_k: torch.Tensor,
-        total_tokens: int,
         max_seqlen_q: int,
         max_seqlen_k: int,
         softmax_scale: float,
         window_size: Optional[list[int]],  # -1 means infinite context window
         alibi_slopes: Optional[list[float]],
         block_table: torch.Tensor,
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
+        total_tokens: int = 0,
     ) -> torch.Tensor:
+        if total_tokens == 0:
+            total_tokens = int(cu_seqlens_k[-1].item())
         k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table,
-                                 k_cache, v_cache, max_seqlen_k, total_tokens)
+                                 k_cache, v_cache, max_seqlen_k, k_scale,
+                                 v_scale, q.dtype, total_tokens)
         output = aiter.flash_attn_varlen_func(
             q=q,
             k=k,
@@ -141,19 +179,21 @@ if current_platform.is_rocm():
         out: torch.Tensor,
         cu_seqlens_q: torch.Tensor,
         cu_seqlens_k: torch.Tensor,
-        total_tokens: int,
         max_seqlen_q: int,
         max_seqlen_k: int,
         softmax_scale: float,
         window_size: Optional[list[int]],  # -1 means infinite context window
         alibi_slopes: Optional[list[float]],
         block_table: torch.Tensor,
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
+        total_tokens: int = 0,
     ) -> torch.Tensor:
         return torch.empty(q.shape[0],
                            q.shape[1],
                            v_cache.shape[-2],
-                           dtype=torch.float8_e4m3fnuz,
-                           device="cuda")
+                           dtype=q.dtype,
+                           device=q.device)
 
     direct_register_custom_op("flash_attn_varlen_func",
                               flash_attn_varlen_func_impl, ["out"],
@@ -163,7 +203,33 @@ if current_platform.is_rocm():
 
 logger = init_logger(__name__)
 
-class AiterFlashAttentionMetadataBuilder:
+
+@dataclass
+class AiterFlashAttentionMetadata:
+    # NOTE(sang): Definition of context_len, query_len, and seq_len.
+    #   |---------- N-1 iteration --------|
+    #   |---------------- N iteration ---------------------|
+    #   |- tokenA -|......................|-- newTokens ---|
+    #   |---------- context_len ----------|
+    #   |-------------------- seq_len ---------------------|
+    #                                     |-- query_len ---|
+
+    num_actual_tokens: int  # Number of tokens excluding padding.
+    max_query_len: int
+    query_start_loc: torch.Tensor
+    max_seq_len: int
+    seq_lens: torch.Tensor
+    slot_mapping: torch.Tensor
+    block_table: torch.Tensor
+
+    # For cascade attention.
+    use_cascade: bool
+    common_prefix_len: int
+
+    total_tokens: int
+
+
+class AiterFlashAttentionMetadataBuilder(
+        AttentionMetadataBuilder[AiterFlashAttentionMetadata]):
+    full_cudagraph_supported: ClassVar[bool] = True
 
     def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
                  device: torch.device):
@@ -180,14 +246,23 @@ class AiterFlashAttentionMetadataBuilder:
         self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None
+        self.total_tokens: int = 0
 
     def reorder_batch(self, input_batch, scheduler_output) -> bool:
         return False
 
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        self.total_tokens = self.model_config.max_model_len \
+            * self.vllm_config.scheduler_config.max_num_partial_prefills
+        res = self.build(common_prefix_len=0,
+                         common_attn_metadata=common_attn_metadata)
+        self.total_tokens = 0
+        return res
+
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
@@ -195,43 +270,29 @@ class AiterFlashAttentionMetadataBuilder:
         num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
         max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
-        total_tokens = int(common_attn_metadata.seq_lens_cpu.sum())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
 
-        cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
-                                  dtype=torch.int32,
-                                  device=self.device)
-        torch.cumsum(seq_lens,
-                     dim=0,
-                     dtype=cu_seq_lens.dtype,
-                     out=cu_seq_lens[1:])
+        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
+                     max_seq_len, causal):
+            return None
 
         use_cascade = common_prefix_len > 0
-        cu_prefix_query_lens = None
-        prefix_kv_lens = None
-        suffix_kv_lens = None
 
         attn_metadata = AiterFlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
-            cu_seq_lens=cu_seq_lens,
-            total_tokens=total_tokens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
-            cu_prefix_query_lens=cu_prefix_query_lens,
-            prefix_kv_lens=prefix_kv_lens,
-            suffix_kv_lens=suffix_kv_lens,
+            total_tokens=self.total_tokens,
         )
         return attn_metadata
@@ -254,7 +315,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
 
     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
-        return [32, 64, 96, 128, 160, 192, 224, 256]
+        return [64, 128, 256]
 
     @classmethod
     def validate_head_size(cls, head_size: int) -> None:
@@ -295,34 +356,6 @@ class AiterFlashAttentionBackend(AttentionBackend):
         return (2, num_blocks, block_size, num_kv_heads, head_size)
 
 
-@dataclass
-class AiterFlashAttentionMetadata:
-    # NOTE(sang): Definition of context_len, query_len, and seq_len.
-    #   |---------- N-1 iteration --------|
-    #   |---------------- N iteration ---------------------|
-    #   |- tokenA -|......................|-- newTokens ---|
-    #   |---------- context_len ----------|
-    #   |-------------------- seq_len ---------------------|
-    #                                     |-- query_len ---|
-
-    num_actual_tokens: int  # Number of tokens excluding padding.
-    max_query_len: int
-    query_start_loc: torch.Tensor
-    max_seq_len: int
-    seq_lens: torch.Tensor
-    cu_seq_lens: torch.Tensor
-    total_tokens: int
-    block_table: torch.Tensor
-    slot_mapping: torch.Tensor
-
-    # For cascade attention.
-    use_cascade: bool
-    common_prefix_len: int
-    cu_prefix_query_lens: Optional[torch.Tensor]
-    prefix_kv_lens: Optional[torch.Tensor]
-    suffix_kv_lens: Optional[torch.Tensor]
-
-
 class AiterFlashAttentionImpl(AttentionImpl):
 
     def __init__(
@@ -366,10 +399,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
                 "encoder/decoder cross-attention "
                 "are not implemented for "
                 "FlashAttentionImpl")
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "AiterFlashAttention does not support fp8 kv-cache on this "
-                "device.")
 
     def forward(
         self,
@@ -440,12 +469,6 @@ class AiterFlashAttentionImpl(AttentionImpl):
         if self.kv_cache_dtype.startswith("fp8"):
             key_cache = key_cache.view(torch.float8_e4m3fnuz)
             value_cache = value_cache.view(torch.float8_e4m3fnuz)
-            num_tokens, num_heads, head_size = query.shape
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape(
-                    (num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale)
-            query = query.reshape((num_tokens, num_heads, head_size))
 
         if not attn_metadata.use_cascade:
             cu_seqlens_q = attn_metadata.query_start_loc
@@ -455,8 +478,16 @@ class AiterFlashAttentionImpl(AttentionImpl):
             block_table = attn_metadata.block_table
 
             if max_seqlen_q > 1:
-                cu_seq_lens = attn_metadata.cu_seq_lens
-                total_tokens = attn_metadata.total_tokens
+                cu_seq_lens = torch.zeros(seqused_k.shape[0] + 1,
+                                          dtype=torch.int32,
+                                          device=query.device)
+                torch.cumsum(seqused_k,
+                             dim=0,
+                             dtype=cu_seq_lens.dtype,
+                             out=cu_seq_lens[1:])
+
                 torch.ops.vllm.flash_attn_varlen_func(
                     query[:num_actual_tokens],
                     key_cache,
@@ -465,29 +496,31 @@ class AiterFlashAttentionImpl(AttentionImpl):
                     cu_seqlens_q=cu_seqlens_q,
                     max_seqlen_q=max_seqlen_q,
                     max_seqlen_k=max_seqlen_k,
-                    total_tokens=total_tokens,
                     softmax_scale=self.scale,
                     alibi_slopes=self.alibi_slopes,
                     window_size=self.sliding_window,
                     block_table=block_table,
-                    cu_seqlens_k=cu_seq_lens)
+                    cu_seqlens_k=cu_seq_lens,
+                    k_scale=layer._k_scale,
+                    v_scale=layer._v_scale,
+                    total_tokens=attn_metadata.total_tokens,
+                )
             else:
                 _, num_heads, head_size = query.shape
-                _PARTITION_SIZE_ROCM = 256
+                nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
                 num_seqs = seqused_k.shape[0]
-                nbyes_per_qo_elem = torch.finfo(output.dtype).bits // 8
                 max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM -
                                       1) // _PARTITION_SIZE_ROCM
 
                 workspace_buffer = torch.empty(
                     (num_seqs * num_heads * max_num_partitions * head_size) *
-                    nbyes_per_qo_elem + 2 *
+                    nbytes_per_qo_elem + 2 *
                     (num_seqs * num_heads * max_num_partitions) * 4,
                     dtype=torch.uint8,
                     device=output.device,
                 )
 
-                aiter.paged_attention_v1(
+                torch.ops.aiter.paged_attention_v1(
                     output[:num_actual_tokens],
                     workspace_buffer,
                     query[:num_actual_tokens],