[ROCm][Perf] New design on ROCm AITER MHA backend Implementation (#25763)

Signed-off-by: ganyi <ygan@amd.com>
Pleaplusone 2025-11-05 02:05:33 +08:00 committed by GitHub
parent 2f1cc8cef1
commit dc937175d4
2 changed files with 595 additions and 277 deletions


@@ -13,223 +13,204 @@ from vllm.attention.backends.abstract import (
AttentionType,
MultipleOf,
)
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_prefills_and_extends,
)
from vllm.v1.kv_cache_interface import AttentionSpec
_PARTITION_SIZE_ROCM = 256
_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
if current_platform.is_rocm():
import aiter
from aiter.ops.triton.utils.device_info import get_num_sms
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
def block_size(x, head_dim):
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
def num_programs(head_dim):
return min(head_dim, get_num_sms())
@triton.jit
def _vllm_layout_trans_kernel(
k_buffer_ptr,
v_buffer_ptr,
k_values_ptr,
v_values_ptr,
b_query_lens_loc,
b_seq_lens_loc,
block_table,
block_table_stride_0,
k_scale,
v_scale,
output_dtype: tl.constexpr,
E_DIM: tl.constexpr,
def cp_mha_gather_cache_kernel(
key_cache_ptr, # [num_blocks, page_size, num_head, head_size]
value_cache_ptr, # [num_blocks, page_size, num_head, head_size]
key_ptr, # [num_tokens, num_heads, head_size]
value_ptr, # [num_tokens, num_heads, head_size]
block_table_ptr, # [num_batches, max_block_num]
cu_seqlens_kv_ptr, # [num_batches + 1]
token_to_batch_ptr, # [max_cum_tokens]
seq_start_ptr, # [num_batches]
k_scale_ptr,
v_scale_ptr,
num_heads,
head_size,
x,
max_block_num,
num_tokens,
DEQUANT: tl.constexpr,
PAGE_SIZE: tl.constexpr,
CACHE_FORMAT: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
NUM_PRGMS: tl.constexpr,
):
batch_idx = tl.program_id(0)
block_idx = tl.program_id(1)
bid = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
if DEQUANT:
k_scale = tl.load(k_scale_ptr)
v_scale = tl.load(v_scale_ptr)
batch_query_indexes = tl.load(b_query_lens_loc + batch_idx + tl.arange(0, 2))
batch_query_start, batch_query_end = tl.split(batch_query_indexes)
query_len = batch_query_end - batch_query_start
if query_len <= 1:
return
batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx + tl.arange(0, 2))
batch_token_start, batch_token_end = tl.split(batch_token_indexes)
seq_len = batch_token_end - batch_token_start
if block_idx * BLOCK_SIZE < seq_len:
block_mask = (
block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[:, None]
) < seq_len
kv_idx = tl.load(
block_table + batch_idx * block_table_stride_0 + block_idx
).to(tl.int64)
kv_buffer_off = (
kv_idx * BLOCK_SIZE * E_DIM
+ tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM
+ tl.arange(0, E_DIM)[None, :]
for token_id in tl.range(bid, num_tokens, NUM_PRGMS):
key_ptr_offset = key_ptr + token_id * head_size * num_heads
value_ptr_offset = value_ptr + token_id * head_size * num_heads
batch_idx = tl.load(token_to_batch_ptr + token_id)
batch_start = tl.load(seq_start_ptr + batch_idx)
token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
batch_offset = token_id - token_start + batch_start
block_offset = batch_offset // PAGE_SIZE
block_id = tl.load(
block_table_ptr + max_block_num * batch_idx + block_offset
)
k_vals = tl.load(k_buffer_ptr + kv_buffer_off, mask=block_mask, other=0.0)
if k_vals.dtype.is_fp8():
k_vals = (k_vals.to(tl.float32) * tl.load(k_scale)).to(output_dtype)
else:
k_vals = k_vals.to(output_dtype)
slot_id = batch_offset % PAGE_SIZE
v_vals = tl.load(v_buffer_ptr + kv_buffer_off, mask=block_mask, other=0.0)
if v_vals.dtype.is_fp8():
v_vals = (v_vals.to(tl.float32) * tl.load(v_scale)).to(output_dtype)
else:
v_vals = v_vals.to(output_dtype)
kv_values_off = (
batch_token_start * E_DIM
+ block_idx * BLOCK_SIZE * E_DIM
+ tl.arange(0, BLOCK_SIZE)[:, None] * E_DIM
+ tl.arange(0, E_DIM)[None, :]
)
tl.store(k_values_ptr + kv_values_off, k_vals, mask=block_mask)
tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)
if CACHE_FORMAT == "NHD":
# for a kv cache laid out as
# K: [num_blocks, page_size, num_head, head_dim]
# V: [num_blocks, page_size, num_head, head_dim]
key_cache_ptr_offset = (
key_cache_ptr
+ block_id * num_heads * head_size * PAGE_SIZE
+ slot_id * num_heads * head_size
)
value_cache_ptr_offset = (
value_cache_ptr
+ block_id * num_heads * head_size * PAGE_SIZE
+ slot_id * num_heads * head_size
)
def vllm_layout_trans(
b_query_lens_loc,
b_seq_lens_loc,
block_table,
k_cache,
v_cache,
max_seq_len,
k_scale,
v_scale,
output_dtype,
total_tokens,
for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
mask = (col_offsets + i) < head_size * num_heads
k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
if DEQUANT:
k_dtype = k_reg.dtype
v_dtype = v_reg.dtype
k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
def cp_mha_gather_cache(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
block_tables: torch.Tensor,
k_scales: torch.Tensor,
v_scales: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
token_to_batch: torch.Tensor,
seq_starts: torch.Tensor,
dequant: bool,
kv_cache_layout: str,
total_tokens: int,
):
H_KV = v_cache.shape[2]
D = v_cache.shape[3]
BLOCK_SIZE = v_cache.shape[1]
k_values = torch.empty(
(total_tokens, H_KV, D),
dtype=output_dtype,
device=k_cache.device,
assert kv_cache_layout in ["v0", "NHD", "HND"], (
"kv_cache_layout only support v0, NHD, HND"
)
v_values = torch.empty(
(total_tokens, H_KV, D),
dtype=output_dtype,
device=v_cache.device,
head_dim = key.shape[2]
x = 0
# assert dequant is True, "Currently, we only support "\
# "gather cache with dequant"
# For k cache layout: [num_blocks, num_heads, page_size, head_dim]
assert kv_cache_layout == "NHD", (
"ROCM_AITER_FA_BACKEND Only support NHD kv cache layout for now"
)
grid = (block_table.shape[0], (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
if output_dtype == torch.float16:
output_dtype = tl.float16
elif output_dtype == torch.bfloat16:
output_dtype = tl.bfloat16
else:
raise ValueError(f"Unsupported output dtype: {output_dtype}")
_vllm_layout_trans_kernel[grid](
k_cache,
v_cache,
k_values,
v_values,
b_query_lens_loc,
b_seq_lens_loc,
block_table,
block_table.stride(0),
k_scale,
v_scale,
output_dtype=output_dtype,
E_DIM=H_KV * D,
BLOCK_SIZE=BLOCK_SIZE,
assert head_dim == key_cache.shape[3], (
"We assume your kv cache layout is [num_blocks, "
"page_size, num_heads, head_dim], but got otherwise"
)
page_size = key_cache.shape[1]
num_heads = key_cache.shape[2]
return k_values, v_values
def flash_attn_varlen_func_impl(
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
out: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
window_size: list[int] | None, # -1 means infinite context window
alibi_slopes: list[float] | None,
block_table: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
total_tokens: int = 0,
) -> torch.Tensor:
if total_tokens == 0:
total_tokens = int(cu_seqlens_k[-1].item())
k, v = vllm_layout_trans(
cu_seqlens_q,
cu_seqlens_k,
block_table,
k_cache,
v_cache,
max_seqlen_k,
k_scale,
v_scale,
q.dtype,
NUM_PRGMS = num_programs(total_tokens)
BLOCK_SIZE = block_size(key_cache, head_dim)
grid = lambda meta: (NUM_PRGMS,)
cp_mha_gather_cache_kernel[grid](
key_cache,
value_cache,
key,
value,
block_tables,
cu_seqlens_kv,
token_to_batch,
seq_starts,
k_scales,
v_scales,
num_heads,
head_dim,
x,
block_tables.size(1),
total_tokens,
DEQUANT=dequant,
PAGE_SIZE=page_size,
CACHE_FORMAT=kv_cache_layout,
BLOCK_SIZE=BLOCK_SIZE,
NUM_PRGMS=NUM_PRGMS,
)
output = aiter.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
min_seqlen_q=1,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=True,
alibi_slopes=alibi_slopes,
window_size=window_size,
out=out,
)
return output
def flash_attn_varlen_func_fake(
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
out: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
window_size: list[int] | None, # -1 means infinite context window
alibi_slopes: list[float] | None,
block_table: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
total_tokens: int = 0,
) -> torch.Tensor:
return torch.empty(
q.shape[0], q.shape[1], v_cache.shape[-2], dtype=q.dtype, device=q.device
)
direct_register_custom_op(
"flash_attn_varlen_func",
flash_attn_varlen_func_impl,
["out"],
flash_attn_varlen_func_fake,
dispatch_key=current_platform.dispatch_key,
)
logger = init_logger(__name__)
@dataclass
class AiterFlashAttentionDecodeMetadata:
max_query_len: int
min_query_len: int
max_seq_len: int
query_start_loc: torch.Tensor
@dataclass
class AiterFlashAttentionPrefillMetadata:
max_query_len: int
min_query_len: int
max_seq_len: int
query_start_loc: torch.Tensor
@dataclass
class AiterChunkContextMetadata:
workspace: torch.Tensor
cu_seq_lens_chunk: torch.Tensor
chunk_starts: torch.Tensor
token_to_batch: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
seq_lens: torch.Tensor
num_chunks: int
total_token_per_batch: list[int]
@dataclass
class AiterFlashAttentionChunkPrefillMetadata:
max_query_len: int
min_query_len: int
max_seq_len: int
query_start_loc: torch.Tensor
chunk_context_metadata: AiterChunkContextMetadata
@dataclass
class AiterFlashAttentionMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
@@ -248,7 +229,18 @@ class AiterFlashAttentionMetadata:
seq_lens: torch.Tensor
slot_mapping: torch.Tensor
block_table: torch.Tensor
cu_seq_lens: torch.Tensor | None
# decode / extend / prefill split
num_decodes: int
num_decode_tokens: int
num_prefills: int
num_prefill_tokens: int
num_extends: int
num_extend_tokens: int
decode_metadata: AiterFlashAttentionDecodeMetadata | None
prefill_metadata: AiterFlashAttentionPrefillMetadata | None
extend_metadata: AiterFlashAttentionChunkPrefillMetadata | None
# For cascade attention.
use_cascade: bool
@@ -260,6 +252,7 @@ class AiterFlashAttentionMetadataBuilder(
AttentionMetadataBuilder[AiterFlashAttentionMetadata]
):
cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: int = 1
def __init__(
self,
@@ -285,6 +278,12 @@ class AiterFlashAttentionMetadataBuilder(
self.aot_sliding_window: tuple[int, int] | None = None
self.total_tokens: int = 0
self.extend_workspace = torch.empty(
[2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim],
dtype=self.model_config.dtype,
device=device,
)
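The workspace above bounds the number of cached context tokens gathered per chunked-prefill iteration at _CP_TOKENS_PER_ITER_ROCM, so its memory footprint is fixed when the builder is constructed. A minimal sizing sketch, assuming a hypothetical configuration (8 KV heads, head dimension 128, bf16); these numbers are illustrative, not taken from this commit:

_CP_TOKENS_PER_ITER_ROCM = 32 * 1024             # tokens gathered per iteration
num_heads_kv, headdim, elem_bytes = 8, 128, 2    # assumed config, bf16 elements

# one K slab plus one V slab of at most _CP_TOKENS_PER_ITER_ROCM tokens
workspace_bytes = 2 * _CP_TOKENS_PER_ITER_ROCM * num_heads_kv * headdim * elem_bytes
print(workspace_bytes // 2**20)  # 128 MiB for this example configuration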
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
):
@@ -302,42 +301,139 @@ class AiterFlashAttentionMetadataBuilder(
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> "AiterFlashAttentionMetadata":
num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
max_seq_len = common_attn_metadata.max_seq_len
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
if max_query_len > 1:
# We pre-compute cumulative seq len needed for prefill attention
# here to avoid recomputing it for every layer
cu_seq_lens = torch.zeros(
seq_lens.shape[0] + 1, dtype=torch.int32, device=seq_lens.device
)
torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:])
num_actual_kv_tokens = int(cu_seq_lens[-1].item())
else:
cu_seq_lens = None
num_actual_kv_tokens = 0
split_ret = split_decodes_prefills_and_extends(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold,
)
def schedule(
batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
):
return None
(
num_decodes,
num_extends,
num_prefills,
num_decode_tokens,
num_extend_tokens,
num_prefill_tokens,
) = split_ret
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens_cpu
query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
decode_metadata = None
if num_decodes > 0:
decode_metadata = AiterFlashAttentionDecodeMetadata(
max_query_len=query_lens_cpu[:num_decodes].max().item(),
min_query_len=query_lens_cpu[:num_decodes].min().item(),
max_seq_len=seq_lens[:num_decodes].max().item(),
query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1],
)
prefill_metadata = None
if num_prefills > 0:
query_lens_for_prefill = query_lens_cpu[num_decodes + num_extends :]
query_start_loc_device = common_attn_metadata.query_start_loc[
num_decodes + num_extends :
]
prefill_metadata = AiterFlashAttentionPrefillMetadata(
max_query_len=query_lens_for_prefill.max().item(),
min_query_len=query_lens_for_prefill.min().item(),
max_seq_len=seq_lens[num_decodes + num_extends :].max().item(),
query_start_loc=query_start_loc_device - query_start_loc_device[0],
)
extend_metadata = None
if num_extends > 0:
num_extends_slice = slice(num_decodes, num_decodes + num_extends)
query_lens_for_extend = query_lens_cpu[num_extends_slice]
seq_lens_for_extend = common_attn_metadata.seq_lens_cpu[num_extends_slice]
computed_kv_lens = seq_lens_for_extend - query_lens_for_extend
# allocate an equal share of the workspace for
# each chunked-prefill request
max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends
num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk)
chunk_starts = (
torch.arange(num_chunks, dtype=torch.int32)
.unsqueeze(1)
.expand(-1, num_extends)
* max_context_chunk
)
chunk_ends = torch.min(
computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk
)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(
min=0
) # [num_chunks, num_extends]
cu_seq_lens_cpu = torch.zeros(
[num_chunks, num_extends + 1], dtype=torch.int32, pin_memory=True
)
torch.cumsum(
chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
)
max_cum_tokens = cu_seq_lens_cpu[:, -1].max().item()
range_idx = torch.arange(max_cum_tokens, dtype=torch.int32)[None, None, :]
idx_to_batch_tensor = range_idx == cu_seq_lens_cpu[:, 1:][:, :, None]
idx_to_batch_tensor = idx_to_batch_tensor.sum(
dim=1
) # [num_chunks, max_cum_tokens]
token_to_batch_tensor = torch.cumsum(idx_to_batch_tensor, dim=1)
chunk_context_metadata = AiterChunkContextMetadata(
workspace=self.extend_workspace,
cu_seq_lens_chunk=cu_seq_lens_cpu.to(self.device, non_blocking=True),
chunk_starts=chunk_starts.to(self.device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
seq_lens=chunk_seq_lens,
token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True),
num_chunks=num_chunks,
total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(),
)
query_start_loc_device = common_attn_metadata.query_start_loc[
num_decodes : num_decodes + num_extends + 1
]
seq_lens_device = common_attn_metadata.seq_lens[num_extends_slice]
cu_seq_lens = torch.zeros(
num_extends + 1, dtype=torch.int32, device=seq_lens_device.device
)
torch.cumsum(
seq_lens_device, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:]
)
extend_metadata = AiterFlashAttentionChunkPrefillMetadata(
max_query_len=query_lens_for_extend.max().item(),
min_query_len=query_lens_for_extend.min().item(),
max_seq_len=seq_lens[num_extends_slice].max().item(),
query_start_loc=query_start_loc_device - query_start_loc_device[0],
chunk_context_metadata=chunk_context_metadata,
)
num_actual_kv_tokens = torch.sum(seq_lens).item()
use_cascade = common_prefix_len > 0
attn_metadata = AiterFlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
num_actual_kv_tokens=num_actual_kv_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table_tensor,
slot_mapping=slot_mapping,
cu_seq_lens=cu_seq_lens,
max_query_len=common_attn_metadata.max_query_len,
query_start_loc=common_attn_metadata.query_start_loc,
max_seq_len=common_attn_metadata.max_seq_len,
seq_lens=common_attn_metadata.seq_lens,
block_table=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_extends=num_extends,
num_extend_tokens=num_extend_tokens,
decode_metadata=decode_metadata,
prefill_metadata=prefill_metadata,
extend_metadata=extend_metadata,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
total_tokens=self.total_tokens,
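A standalone toy walk-through of the chunk bookkeeping built above for the extend requests, using plain CPU tensors and hypothetical sizes: two extend requests with 40 and 10 already-computed context tokens and a per-request budget of 32 tokens per iteration produce two chunks, and token_to_batch maps every gathered token of a chunk back to the request it belongs to:

import torch

computed_kv_lens = torch.tensor([40, 10], dtype=torch.int32)  # context per extend request
max_context_chunk = 32                                        # per-request budget per iteration
num_chunks = (computed_kv_lens.max().item() + max_context_chunk - 1) // max_context_chunk  # 2

chunk_starts = (
    torch.arange(num_chunks, dtype=torch.int32).unsqueeze(1).expand(-1, 2)
    * max_context_chunk
)                                                             # [[0, 0], [32, 32]]
chunk_ends = torch.min(computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)     # [[32, 10], [8, 0]]

cu_seq_lens = torch.zeros(num_chunks, 3, dtype=torch.int32)
torch.cumsum(chunk_seq_lens, dim=1, out=cu_seq_lens[:, 1:], dtype=torch.int32)  # [[0, 32, 42], [0, 8, 8]]

max_cum_tokens = cu_seq_lens[:, -1].max().item()              # 42
range_idx = torch.arange(max_cum_tokens, dtype=torch.int32)[None, None, :]
idx_to_batch = (range_idx == cu_seq_lens[:, 1:][:, :, None]).sum(dim=1)
token_to_batch = torch.cumsum(idx_to_batch, dim=1)
# In chunk 0, gathered tokens 0..31 map to request 0 and tokens 32..41 map to
# request 1, which is how the gather kernel picks the right block-table row.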
@@ -401,6 +497,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@@ -449,6 +546,110 @@ class AiterFlashAttentionImpl(AttentionImpl):
"FlashAttentionImpl"
)
def extend_forward(
self,
attn_metadata: AiterFlashAttentionMetadata,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
output: torch.Tensor,
cu_seqlens_q: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
min_seqlen_q: int,
block_table: torch.Tensor,
slot_mapping: torch.Tensor,
k_scale: float,
v_scale: float,
):
out, lse = aiter.flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_q,
min_seqlen_q=min_seqlen_q,
dropout_p=0.0,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
return_lse=True,
)
assert attn_metadata.extend_metadata is not None
chunk_context_metadata = attn_metadata.extend_metadata.chunk_context_metadata
num_chunks = chunk_context_metadata.num_chunks
workspace = chunk_context_metadata.workspace
cu_seqlens_kv = chunk_context_metadata.cu_seq_lens_chunk
max_seqlens = chunk_context_metadata.max_seq_lens
chunk_starts = chunk_context_metadata.chunk_starts
token_to_batch = chunk_context_metadata.token_to_batch
total_token_per_batch = chunk_context_metadata.total_token_per_batch
key_fetched, value_fetched = workspace[0], workspace[1]
chunked_output = None
chunked_lse = None
for chunk_idx in range(num_chunks):
cp_mha_gather_cache(
key_cache=key_cache,
value_cache=value_cache,
key=key_fetched,
value=value_fetched,
block_tables=block_table,
k_scales=k_scale,
v_scales=v_scale,
cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
token_to_batch=token_to_batch[chunk_idx],
seq_starts=chunk_starts[chunk_idx],
dequant=False,
kv_cache_layout="NHD",
total_tokens=total_token_per_batch[chunk_idx],
)
suf_out, suf_lse = aiter.flash_attn_varlen_func(
q=query,
k=key_fetched,
v=value_fetched,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_kv[chunk_idx],
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlens[chunk_idx],
min_seqlen_q=min_seqlen_q,
dropout_p=0.0,
softmax_scale=self.scale,
causal=False,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
return_lse=True,
)
if chunked_output is None:
chunked_output = suf_out
chunked_lse = suf_lse
else:
tmp_output = torch.empty_like(out)
tmp_lse = torch.empty_like(lse)
merge_attn_states(
output=tmp_output,
output_lse=tmp_lse,
prefix_output=chunked_output,
prefix_lse=chunked_lse,
suffix_output=suf_out,
suffix_lse=suf_lse,
)
chunked_output = tmp_output
chunked_lse = tmp_lse
merge_attn_states(
output=output,
prefix_output=chunked_output,
prefix_lse=chunked_lse,
suffix_output=out,
suffix_lse=lse,
)
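The chunk loop above relies on the standard attention-state merge: softmax attention over a union of disjoint key sets can be reconstructed exactly from the per-chunk outputs and their per-query log-sum-exp (LSE) values. A minimal reference for what each merge step computes, with illustrative shapes (outputs [num_tokens, num_heads, head_dim], LSEs [num_tokens, num_heads]); the actual merge_attn_states kernel may use a different LSE layout:

import torch

def merge_two_attn_states(o_a, lse_a, o_b, lse_b):
    # o_*: [num_tokens, num_heads, head_dim]; lse_*: [num_tokens, num_heads]
    lse_max = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse_max)            # numerically stable rescale factors
    w_b = torch.exp(lse_b - lse_max)
    denom = w_a + w_b
    merged = (w_a.unsqueeze(-1) * o_a + w_b.unsqueeze(-1) * o_b) / denom.unsqueeze(-1)
    merged_lse = lse_max + torch.log(denom)     # LSE over the combined key set
    return merged, merged_lse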
def forward(
self,
layer: torch.nn.Module,
@@ -488,24 +689,25 @@ class AiterFlashAttentionImpl(AttentionImpl):
return output.fill_(0)
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
# NOTE(woosuk): With piece-wise CUDA graphs, this method is
# executed in eager-mode PyTorch. Thus, we need to be careful
# about any CPU overhead in this method. For example, `view`
# and `slice` (or `[:n]`) operations are surprisingly slow even
# in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
# NOTE(woosuk): Here, key and value are padded while slot_mapping
# is not padded. However, we don't need to do
# key[:num_actual_tokens] and value[:num_actual_tokens] because
# the reshape_and_cache_flash op uses the slot_mapping's shape
# to determine the number of actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
@@ -521,69 +723,118 @@ class AiterFlashAttentionImpl(AttentionImpl):
key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype())
if not attn_metadata.use_cascade:
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
# batch layout after reordering: decode : extend : prefill
query = query[:num_actual_tokens]
key = key[:num_actual_tokens]
value = value[:num_actual_tokens]
if max_seqlen_q > 1:
torch.ops.vllm.flash_attn_varlen_func(
query[:num_actual_tokens],
key_cache,
value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
output_actual_tokens = output[:num_actual_tokens]
num_decodes = attn_metadata.num_decodes
num_prefills = attn_metadata.num_prefills
num_extends = attn_metadata.num_extends
num_decode_tokens = attn_metadata.num_decode_tokens
num_extend_tokens = attn_metadata.num_extend_tokens
if not attn_metadata.use_cascade:
# calculate for pure prefills
if num_prefills > 0:
assert attn_metadata.prefill_metadata is not None
prefill_query = query[num_decode_tokens + num_extend_tokens :]
prefill_key = key[num_decode_tokens + num_extend_tokens :]
prefill_value = value[num_decode_tokens + num_extend_tokens :]
aiter.flash_attn_varlen_func(
q=prefill_query,
k=prefill_key,
v=prefill_value,
cu_seqlens_q=attn_metadata.prefill_metadata.query_start_loc,
cu_seqlens_k=attn_metadata.prefill_metadata.query_start_loc,
max_seqlen_q=attn_metadata.prefill_metadata.max_query_len,
max_seqlen_k=attn_metadata.prefill_metadata.max_seq_len,
min_seqlen_q=attn_metadata.prefill_metadata.min_query_len,
dropout_p=0.0,
softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes,
causal=True,
window_size=self.sliding_window,
block_table=block_table,
cu_seqlens_k=attn_metadata.cu_seq_lens,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
total_tokens=attn_metadata.num_actual_kv_tokens,
alibi_slopes=self.alibi_slopes,
out=output_actual_tokens[num_decode_tokens + num_extend_tokens :],
)
_, num_heads, head_size = query.shape
nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
num_seqs = seqused_k.shape[0]
max_num_partitions = (
max_seqlen_k + _PARTITION_SIZE_ROCM - 1
) // _PARTITION_SIZE_ROCM
# calculate for extends
if num_extends > 0:
assert attn_metadata.extend_metadata is not None
extend_tokens_slice = slice(
num_decode_tokens, num_decode_tokens + num_extend_tokens
)
extend_querys = query[extend_tokens_slice]
extend_keys = key[extend_tokens_slice]
extend_values = value[extend_tokens_slice]
extend_outputs = output[extend_tokens_slice]
self.extend_forward(
attn_metadata=attn_metadata,
query=extend_querys,
key=extend_keys,
value=extend_values,
key_cache=key_cache,
value_cache=value_cache,
output=extend_outputs,
cu_seqlens_q=attn_metadata.extend_metadata.query_start_loc,
max_seqlen_q=attn_metadata.extend_metadata.max_query_len,
max_seqlen_k=attn_metadata.extend_metadata.max_seq_len,
min_seqlen_q=attn_metadata.extend_metadata.min_query_len,
block_table=attn_metadata.block_table[
num_decodes : num_decodes + num_extends
],
slot_mapping=attn_metadata.slot_mapping[
num_decodes : num_decodes + num_extends
],
k_scale=layer._k_scale,
v_scale=layer._v_scale,
)
workspace_buffer = torch.empty(
(num_seqs * num_heads * max_num_partitions * head_size)
* nbytes_per_qo_elem
+ 2 * (num_seqs * num_heads * max_num_partitions) * 4,
dtype=torch.uint8,
device=output.device,
)
# calculate for decodes
if num_decodes > 0:
assert attn_metadata.decode_metadata is not None
_, num_heads, head_size = query.shape
nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
num_seqs = attn_metadata.seq_lens.shape[0]
max_num_partitions = (
attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
) // _PARTITION_SIZE_ROCM
torch.ops.aiter.paged_attention_v1(
output[:num_actual_tokens],
workspace_buffer,
query[:num_actual_tokens],
key_cache,
value_cache,
self.scale,
block_table,
cu_seqlens_q,
seqused_k,
max_seqlen_k,
self.alibi_slopes,
self.kv_cache_dtype,
"NHD",
self.logits_soft_cap,
layer._k_scale,
layer._v_scale,
None,
_PARTITION_SIZE_ROCM,
)
return output
workspace_buffer = torch.empty(
(num_seqs * num_heads * max_num_partitions * head_size)
* nbytes_per_qo_elem
+ 2 * (num_seqs * num_heads * max_num_partitions) * 4,
dtype=torch.uint8,
device=output.device,
)
torch.ops.aiter.paged_attention_v1(
output[:num_decode_tokens],
workspace_buffer,
query[:num_decode_tokens],
key_cache,
value_cache,
self.scale,
attn_metadata.block_table[:num_decodes],
attn_metadata.query_start_loc[:num_decodes],
attn_metadata.seq_lens[:num_decodes],
attn_metadata.max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
"NHD",
self.logits_soft_cap,
layer._k_scale,
layer._v_scale,
None,
_PARTITION_SIZE_ROCM,
)
else:
raise NotImplementedError(
"Cascade attention is not implemented for ROCM AITER"
)
return output
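For the decode path, the split-KV workspace passed to paged_attention_v1 scales with the decode batch size, the number of query heads, and the number of _PARTITION_SIZE_ROCM-sized partitions of the longest sequence. A rough sizing sketch, assuming hypothetical numbers (64 decode sequences, 32 query heads, head size 128, max sequence length 8192, bf16 queries):

num_seqs, num_heads, head_size = 64, 32, 128   # assumed decode batch / model
max_seq_len, nbytes_per_qo_elem = 8192, 2      # bf16 query/output elements
_PARTITION_SIZE_ROCM = 256

max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM  # 32
workspace_bytes = (
    num_seqs * num_heads * max_num_partitions * head_size * nbytes_per_qo_elem  # partial outputs
    + 2 * (num_seqs * num_heads * max_num_partitions) * 4  # two float32 reduction buffers
)
print(workspace_bytes / 2**20)  # ~16.5 MiB for this example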


@@ -728,6 +728,73 @@ def subclass_attention_backend(
)
def split_decodes_prefills_and_extends(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
) -> tuple[int, int, int, int, int, int]:
"""
Assuming a reordered batch, finds the boundaries between the decode,
extend (chunked prefill), and pure prefill requests.
Args:
common_attn_metadata: CommonAttentionMetadata object containing the
batch metadata.
decode_threshold: The maximum query length to be considered a decode.
Returns:
num_decodes: The number of decode requests.
num_extends: The number of extend requests.
num_prefills: The number of prefill requests.
num_decode_tokens: The number of tokens in the decode requests.
num_extend_tokens: The number of tokens in the extend requests.
num_prefill_tokens: The number of tokens in the prefill requests.
"""
max_query_len = common_attn_metadata.max_query_len
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens_cpu
if max_query_len <= decode_threshold:
return num_reqs, 0, 0, num_tokens, 0, 0
query_lens = query_start_loc[1:] - query_start_loc[:-1]
is_prefill_or_extend = query_lens > decode_threshold
is_prefill = (seq_lens == query_lens) & is_prefill_or_extend
first_extend = is_prefill_or_extend.int().argmax(dim=-1).item()
first_prefill = is_prefill.int().argmax(dim=-1).item()
num_decodes = first_extend
num_decode_tokens = query_start_loc[first_extend].item()
if not torch.any(is_prefill_or_extend):
return (num_decodes, 0, 0, num_decode_tokens, 0, 0)
num_prefills_or_extends = num_reqs - num_decodes
num_prefill_or_extend_tokens = num_tokens - num_decode_tokens
if not torch.any(is_prefill):
return (
num_decodes,
num_prefills_or_extends,
0,
num_decode_tokens,
num_prefill_or_extend_tokens,
0,
)
num_extends = first_prefill - num_decodes
num_prefills = num_reqs - first_prefill
num_prefill_tokens = num_tokens - query_start_loc[first_prefill]
num_extend_tokens = num_prefill_or_extend_tokens - num_prefill_tokens
return (
num_decodes,
num_extends,
num_prefills,
num_decode_tokens,
num_extend_tokens,
num_prefill_tokens,
)
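A toy walk-through of the split, using plain tensors instead of a real CommonAttentionMetadata and assuming the batch has already been reordered as [decodes | extends | prefills]:

import torch

decode_threshold = 1
# req 0: query_len 1,  seq_len 100 -> decode  (query_len <= threshold)
# req 1: query_len 16, seq_len 48  -> extend  (has already-computed context)
# req 2: query_len 32, seq_len 32  -> prefill (seq_len == query_len)
query_start_loc = torch.tensor([0, 1, 17, 49])
seq_lens = torch.tensor([100, 48, 32])

query_lens = query_start_loc[1:] - query_start_loc[:-1]       # [1, 16, 32]
is_prefill_or_extend = query_lens > decode_threshold          # [False, True, True]
is_prefill = (seq_lens == query_lens) & is_prefill_or_extend  # [False, False, True]

num_decodes = int(is_prefill_or_extend.int().argmax())        # 1
num_extends = int(is_prefill.int().argmax()) - num_decodes    # 1
num_prefills = seq_lens.numel() - num_decodes - num_extends   # 1
num_decode_tokens = int(query_start_loc[num_decodes])         # 1
num_prefill_tokens = int(query_start_loc[-1] - query_start_loc[num_decodes + num_extends])  # 32
num_extend_tokens = int(query_start_loc[-1]) - num_decode_tokens - num_prefill_tokens       # 16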
def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,