From 8c363ed6663f69b97c9f34b0be0091d8135f958c Mon Sep 17 00:00:00 2001 From: Pleaplusone Date: Sun, 30 Nov 2025 19:31:50 +0800 Subject: [PATCH] [ROCm][Attention] Sliding window support for `AiterFlashAttentionBackend` (#29234) Signed-off-by: ganyi --- vllm/v1/attention/backends/rocm_aiter_fa.py | 273 ++++++++++++++++---- 1 file changed, 224 insertions(+), 49 deletions(-) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index ea911af3d19ce..b6aa0ae2be48e 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -13,8 +13,9 @@ from vllm.attention.backends.abstract import ( AttentionType, MultipleOf, ) +from vllm.attention.layer import Attention from vllm.attention.ops.merge_attn_states import merge_attn_states -from vllm.config import VllmConfig +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv @@ -57,58 +58,55 @@ if current_platform.is_rocm(): head_size, x, max_block_num, - num_tokens, - num_programs, DEQUANT: tl.constexpr, PAGE_SIZE: tl.constexpr, CACHE_FORMAT: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): - bid = tl.program_id(0) + token_id = tl.program_id(0) col_offsets = tl.arange(0, BLOCK_SIZE) if DEQUANT: k_scale = tl.load(k_scale_ptr) v_scale = tl.load(v_scale_ptr) - for token_id in tl.range(bid, num_tokens, num_programs): - key_ptr_offset = key_ptr + token_id * head_size * num_heads - value_ptr_offset = value_ptr + token_id * head_size * num_heads - batch_idx = tl.load(token_to_batch_ptr + token_id) - batch_start = tl.load(seq_start_ptr + batch_idx) - token_start = tl.load(cu_seqlens_kv_ptr + batch_idx) - batch_offset = token_id - token_start + batch_start - block_offset = batch_offset // PAGE_SIZE - block_id = tl.load( - block_table_ptr + max_block_num * batch_idx + block_offset + key_ptr_offset = key_ptr + token_id * head_size * num_heads + value_ptr_offset = value_ptr + token_id * head_size * num_heads + batch_idx = tl.load(token_to_batch_ptr + token_id) + batch_start = tl.load(seq_start_ptr + batch_idx) + token_start = tl.load(cu_seqlens_kv_ptr + batch_idx) + batch_offset = token_id - token_start + batch_start + block_offset = batch_offset // PAGE_SIZE + block_id = tl.load( + block_table_ptr + max_block_num * batch_idx + block_offset + ).to(tl.int64) + slot_id = batch_offset % PAGE_SIZE + + if CACHE_FORMAT == "NHD": + # for kv cache layout as + # K: [num_blocks, page_size, num_head, head_dim] + # V: [num_blocks, page_size, num_head, head_dim] + key_cache_ptr_offset = ( + key_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + slot_id * num_heads * head_size + ) + value_cache_ptr_offset = ( + value_cache_ptr + + block_id * num_heads * head_size * PAGE_SIZE + + slot_id * num_heads * head_size ) - slot_id = batch_offset % PAGE_SIZE - if CACHE_FORMAT == "NHD": - # for kv cache layout as - # K: [num_blocks, page_size, num_head, head_dim] - # V: [num_blocks, page_size, num_head, head_dim] - key_cache_ptr_offset = ( - key_cache_ptr - + block_id * num_heads * head_size * PAGE_SIZE - + slot_id * num_heads * head_size - ) - value_cache_ptr_offset = ( - value_cache_ptr - + block_id * num_heads * head_size * PAGE_SIZE - + slot_id * num_heads * head_size - ) - - for i in tl.range(0, head_size * num_heads, BLOCK_SIZE): - mask = (col_offsets + i) < head_size * num_heads - k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask) - v_reg = 
tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask) - if DEQUANT: - k_dtype = k_reg.dtype - v_dtype = v_reg.dtype - k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype) - v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype) - tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask) - tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask) + for i in tl.range(0, head_size * num_heads, BLOCK_SIZE): + mask = (col_offsets + i) < head_size * num_heads + k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask) + v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask) + if DEQUANT: + k_dtype = k_reg.dtype + v_dtype = v_reg.dtype + k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype) + v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype) + tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask) + tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask) def cp_mha_gather_cache( key_cache: torch.Tensor, @@ -143,9 +141,7 @@ if current_platform.is_rocm(): page_size = key_cache.shape[1] num_heads = key_cache.shape[2] - NUM_PRGMS = num_programs(total_tokens) - BLOCK_SIZE = block_size(key_cache, head_dim) - grid = lambda meta: (NUM_PRGMS,) + grid = lambda meta: (total_tokens,) cp_mha_gather_cache_kernel[grid]( key_cache, value_cache, @@ -161,12 +157,10 @@ if current_platform.is_rocm(): head_dim, x, block_tables.size(1), - total_tokens, - NUM_PRGMS, DEQUANT=dequant, PAGE_SIZE=page_size, CACHE_FORMAT=kv_cache_layout, - BLOCK_SIZE=BLOCK_SIZE, + BLOCK_SIZE=head_dim, ) @@ -189,6 +183,17 @@ class AiterFlashAttentionPrefillMetadata: query_start_loc: torch.Tensor +@dataclass +class AiterChunkSlidingWindowMetadata: + swa_seqlens: torch.Tensor + swa_cu_seqlens: torch.Tensor + swa_seq_starts: torch.Tensor + swa_token_to_batch: torch.Tensor + swa_max_seqlens: int + swa_total_tokens: int + swa_workspace: torch.Tensor + + @dataclass class AiterChunkContextMetadata: workspace: torch.Tensor @@ -200,6 +205,7 @@ class AiterChunkContextMetadata: seq_lens: torch.Tensor num_chunks: int total_token_per_batch: list[int] + swa_metadata: AiterChunkSlidingWindowMetadata | None @dataclass @@ -278,6 +284,20 @@ class AiterFlashAttentionMetadataBuilder( self.aot_sliding_window: tuple[int, int] | None = None self.total_tokens: int = 0 + sliding_window_configs: set[tuple[int, int] | None] = set() + layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer in layers.values(): + assert isinstance(layer.impl, AiterFlashAttentionImpl) + sliding_window_configs.add(layer.impl.sliding_window) + + while len(sliding_window_configs) > 0: + sliding_window_config = sliding_window_configs.pop() + if sliding_window_config is not None and sliding_window_config[0] != -1: + assert self.aot_sliding_window is None, ( + "Aiter Flash ATTENTION can only support one valid sliding window!" 
+ ) + self.aot_sliding_window = sliding_window_config + self.extend_workspace = torch.empty( [2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim], dtype=self.model_config.dtype, @@ -349,6 +369,55 @@ class AiterFlashAttentionMetadataBuilder( query_lens_for_extend = query_lens_cpu[num_extends_slice] seq_lens_for_extend = common_attn_metadata.seq_lens_cpu[num_extends_slice] computed_kv_lens = seq_lens_for_extend - query_lens_for_extend + swa_metadata = None + if self.aot_sliding_window is not None: + swa_seqlen_for_extend = torch.minimum( + seq_lens_for_extend, + query_lens_for_extend + self.aot_sliding_window[0] + 1, + ) + cu_seq_lens = torch.zeros( + num_extends + 1, + dtype=torch.int32, + device=seq_lens_for_extend.device, + ) + torch.cumsum( + swa_seqlen_for_extend, + dim=0, + dtype=cu_seq_lens.dtype, + out=cu_seq_lens[1:], + ) + token_to_seq = torch.arange( + 0, + num_extends, + dtype=torch.int32, + device=seq_lens_for_extend.device, + ) + token_to_seq = torch.repeat_interleave( + token_to_seq, swa_seqlen_for_extend + ) + fetched_shape = cu_seq_lens[-1].item() + # TODO(ganyi): Maybe reuse these 2 buffer from extend_workspace + swa_workspace = torch.empty( + (2, fetched_shape, self.num_heads_kv, self.headdim), + dtype=self.vllm_config.model_config.dtype, + device=self.device, + ) + + seq_starts = seq_lens_for_extend - swa_seqlen_for_extend + max_seqlen_k = swa_seqlen_for_extend.max().item() + total_tokens = cu_seq_lens[-1].item() + + swa_metadata = AiterChunkSlidingWindowMetadata( + swa_seqlens=swa_seqlen_for_extend.to( + self.device, non_blocking=True + ), + swa_cu_seqlens=cu_seq_lens.to(self.device, non_blocking=True), + swa_seq_starts=seq_starts.to(self.device, non_blocking=True), + swa_token_to_batch=token_to_seq.to(self.device, non_blocking=True), + swa_max_seqlens=max_seqlen_k, + swa_total_tokens=total_tokens, + swa_workspace=swa_workspace, + ) # allocate the equal amount of workspace for # each chunk prefill request @@ -392,6 +461,7 @@ class AiterFlashAttentionMetadataBuilder( token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True), num_chunks=num_chunks, total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(), + swa_metadata=swa_metadata, ) query_start_loc_device = common_attn_metadata.query_start_loc[ @@ -504,9 +574,9 @@ class AiterFlashAttentionImpl(AttentionImpl): alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes if sliding_window is None: - self.sliding_window = [-1, -1] + self.sliding_window = (-1, -1) else: - self.sliding_window = [sliding_window - 1, 0] + self.sliding_window = (sliding_window - 1, 0) self.kv_cache_dtype = kv_cache_dtype if logits_soft_cap is None: # In flash-attn, setting logits_soft_cap as 0 means no soft cap. 
@@ -522,6 +592,67 @@ class AiterFlashAttentionImpl(AttentionImpl): "Encoder self-attention is not implemented for FlashAttentionImpl" ) + def extend_for_sliding_window( + self, + attn_metadata: AiterFlashAttentionMetadata, + query: torch.Tensor, + key_cache, + value_cache, + output: torch.Tensor, + cu_seqlens_q: torch.Tensor, + max_seqlen_q: int, + block_table: torch.Tensor, + k_scale: float, + v_scale: float, + ): + assert attn_metadata.extend_metadata is not None + assert attn_metadata.extend_metadata.chunk_context_metadata is not None + chunked_metadata = attn_metadata.extend_metadata.chunk_context_metadata + swa_metadata = chunked_metadata.swa_metadata + assert swa_metadata is not None + swa_cu_seqlens = swa_metadata.swa_cu_seqlens + swa_seq_starts = swa_metadata.swa_seq_starts + swa_token_to_batch = swa_metadata.swa_token_to_batch + swa_max_seqlens = swa_metadata.swa_max_seqlens + swa_total_tokens = swa_metadata.swa_total_tokens + key_fetched, value_fetched = ( + swa_metadata.swa_workspace[0], + swa_metadata.swa_workspace[1], + ) + cp_mha_gather_cache( + key_cache=key_cache, + value_cache=value_cache, + key=key_fetched, + value=value_fetched, + block_tables=block_table, + k_scales=k_scale, + v_scales=v_scale, + cu_seqlens_kv=swa_cu_seqlens, + token_to_batch=swa_token_to_batch, + seq_starts=swa_seq_starts, + dequant=False, + kv_cache_layout="NHD", + total_tokens=swa_total_tokens, + ) + + aiter.flash_attn_varlen_func( + q=query, + k=key_fetched, + v=value_fetched, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=swa_cu_seqlens, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=swa_max_seqlens, + min_seqlen_q=1, + dropout_p=0.0, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + return_lse=False, + out=output, + ) + def extend_forward( self, attn_metadata: AiterFlashAttentionMetadata, @@ -540,6 +671,20 @@ class AiterFlashAttentionImpl(AttentionImpl): k_scale: float, v_scale: float, ): + if self.sliding_window[0] != -1: + self.extend_for_sliding_window( + attn_metadata, + query, + key_cache, + value_cache, + output, + cu_seqlens_q, + max_seqlen_q, + block_table, + k_scale, + v_scale, + ) + return out, lse = aiter.flash_attn_varlen_func( q=query, k=key, @@ -782,6 +927,36 @@ class AiterFlashAttentionImpl(AttentionImpl): # calculate for decodes if num_decodes > 0: assert attn_metadata.decode_metadata is not None + if self.sliding_window[0] != -1: + from aiter.ops.triton.unified_attention import ( + unified_attention, + ) + + descale_shape = ( + attn_metadata.query_start_loc[:num_decodes].shape[0] - 1, + key_cache.shape[2], + ) + unified_attention( + q=query[:num_decode_tokens], + k=key_cache, + v=value_cache, + out=output[:num_decode_tokens], + cu_seqlens_q=attn_metadata.query_start_loc[:num_decodes], + max_seqlen_q=1, # optimize this + seqused_k=attn_metadata.seq_lens[:num_decodes], + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=attn_metadata.block_table[:num_decodes], + softcap=self.logits_soft_cap, + q_descale=None, + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + ) + return + assert attn_metadata.decode_metadata is not None _, num_heads, head_size = query.shape nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8 num_seqs = attn_metadata.seq_lens.shape[0]
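
For reference, the core bookkeeping this patch adds to the metadata builder can be reproduced in isolation. The sketch below is not part of the patch; the tensor values are made-up examples and the variable names are only illustrative. It shows how the per-request sliding-window key lengths, their cumulative offsets, the window start positions, and the token-to-request mapping are derived before the gathered keys/values are handed to aiter.flash_attn_varlen_func:

    # Standalone sketch (not part of the patch): CPU reproduction of the
    # sliding-window bookkeeping computed for chunked-prefill requests.
    # All tensor values below are made-up examples.
    import torch

    # Per-request lengths for three in-flight extend requests (example values).
    seq_lens = torch.tensor([4096, 512, 2048], dtype=torch.int32)   # total KV length
    query_lens = torch.tensor([128, 512, 256], dtype=torch.int32)   # new tokens this step
    left_window = 1023  # corresponds to self.sliding_window[0], i.e. sliding_window - 1

    # Each query token attends to at most (query_len + window) keys, so the
    # gathered KV span per request is capped accordingly.
    swa_seqlens = torch.minimum(seq_lens, query_lens + left_window + 1)

    # Exclusive prefix sums locate each request's keys inside the gathered workspace.
    swa_cu_seqlens = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
    torch.cumsum(swa_seqlens, dim=0, dtype=torch.int32, out=swa_cu_seqlens[1:])

    # Where each request's window starts inside its full KV sequence.
    swa_seq_starts = seq_lens - swa_seqlens

    # Maps every gathered KV token back to its request index.
    swa_token_to_batch = torch.repeat_interleave(
        torch.arange(seq_lens.numel(), dtype=torch.int32), swa_seqlens.to(torch.long)
    )

    print(swa_seqlens.tolist())     # [1152, 512, 1280]
    print(swa_cu_seqlens.tolist())  # [0, 1152, 1664, 2944]
    print(swa_seq_starts.tolist())  # [2944, 0, 768]

Capping each request's gathered KV span at query_len + window means the swa_workspace allocation scales with the attention window rather than with the full context length, which is what makes gathering the cache into a contiguous buffer affordable for long-context sliding-window layers.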