From 25b5317ce46e3f4ac8a7ed548106291fae63c5c0 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Thu, 18 Dec 2025 16:16:19 +0000 Subject: [PATCH] Merge ROCm:rope_kvcache_fusion Signed-off-by: Gregory Shtrasberg --- vllm/attention/layer.py | 30 ++++++- vllm/model_executor/models/gpt_oss.py | 8 +- vllm/model_executor/models/llama.py | 9 +- vllm/v1/attention/backends/rocm_aiter_fa.py | 91 ++++++++++++++++----- vllm/v1/attention/backends/rocm_attn.py | 73 ++++++++++++++--- vllm/v1/attention/backends/triton_attn.py | 87 +++++++++++++++----- 6 files changed, 239 insertions(+), 59 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 7ef77db8fbb5b..108ce7176b965 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -142,6 +142,7 @@ class Attention(nn.Module, AttentionLayerBase): attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: str | None = None, attn_backend: type[AttentionBackend] | None = None, + rotary_emb: nn.Module | None = None, **extra_impl_args, ) -> None: """ @@ -207,7 +208,6 @@ class Attention(nn.Module, AttentionLayerBase): ) else: self.attn_backend = attn_backend - # prefix caching + batch invariance is currently not supported for # FLASHINFER and TRITON_MLA. if ( @@ -240,6 +240,7 @@ class Attention(nn.Module, AttentionLayerBase): kv_sharing_target_layer_name, **extra_impl_args, ) + self.impl.rotary_emb = rotary_emb backend_name = self.attn_backend.get_name() self.backend = AttentionBackendEnum.__members__.get(backend_name) self.dtype = dtype @@ -296,6 +297,7 @@ class Attention(nn.Module, AttentionLayerBase): # shape does not match the query shape, so we optionally let the model # definition specify the output tensor shape. 
output_shape: torch.Size | None = None, + positions: torch.Tensor | None = None, ) -> torch.Tensor: """ The KV cache is stored inside this class and is accessed via @@ -345,7 +347,7 @@ class Attention(nn.Module, AttentionLayerBase): ) else: torch.ops.vllm.unified_attention_with_output( - query, key, value, output, self.layer_name + query, key, value, output, self.layer_name, positions=positions ) return output.view(-1, hidden_size) else: @@ -564,6 +566,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): prefix: str = "", use_sparse: bool = False, indexer: object | None = None, + rotary_emb: nn.Module | None = None, **extra_impl_args, ): super().__init__() @@ -641,6 +644,8 @@ class MLAAttention(nn.Module, AttentionLayerBase): **extra_impl_args, ) + self.impl.rotary_emb = rotary_emb + self.use_direct_call = not current_platform.opaque_attention_op() compilation_config = get_current_vllm_config().compilation_config @@ -668,6 +673,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): kv_c_normed: torch.Tensor, k_pe: torch.Tensor, output_shape: torch.Size | None = None, + positions: torch.Tensor | None = None, ) -> torch.Tensor: if self.calculate_kv_scales: torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name) @@ -704,6 +710,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): k_pe, output, self.layer_name, + positions=positions, ) return output else: @@ -858,8 +865,24 @@ def unified_attention_with_output( layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> None: attn_metadata, self, kv_cache = get_attention_context(layer_name) + if positions is not None: + assert hasattr(self.impl, "rotary_emb") and self.impl.rotary_emb is not None + self.impl.forward( + self, + query, + key, + value, + kv_cache, + attn_metadata, + output=output, + output_scale=output_scale, + output_block_scale=output_block_scale, + positions=positions, + ) + return 
self.impl.forward( self, query, @@ -881,6 +904,7 @@ def unified_attention_with_output_fake( layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> None: return @@ -931,6 +955,7 @@ def unified_mla_attention_with_output( k_pe: torch.Tensor, output: torch.Tensor, layer_name: str, + positions: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: @@ -954,6 +979,7 @@ def unified_mla_attention_with_output_fake( k_pe: torch.Tensor, output: torch.Tensor, layer_name: str, + positions: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> None: diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index 6a92cf1533213..a423d54a31889 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -7,6 +7,7 @@ import torch.distributed as dist from torch import nn from transformers import GptOssConfig +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile @@ -125,6 +126,9 @@ class OAIAttention(nn.Module): attn_type=AttentionType.DECODER, prefix=f"{prefix}.attn", sinks=self.sinks, + rotary_emb=self.rotary_emb + if current_platform.is_rocm() and rocm_aiter_ops.is_enabled() + else None, ) def forward( @@ -132,9 +136,9 @@ class OAIAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) + # q, k = self.rotary_emb(positions, q, k) v = v.contiguous() - attn_output = self.attn(q, k, v) + attn_output = self.attn(q, k, v, positions=positions) output, _ = self.o_proj(attn_output) return output diff 
--git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 3507a2bc66c17..d978ab3c12486 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -31,6 +31,7 @@ import torch from torch import nn from transformers import LlamaConfig +from vllm._aiter_ops import rocm_aiter_ops from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention @@ -55,6 +56,7 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, ) +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP @@ -219,6 +221,9 @@ class LlamaAttention(nn.Module): per_layer_sliding_window=sliding_window, attn_type=attn_type, prefix=f"{prefix}.attn", + rotary_emb=self.rotary_emb + if current_platform.is_rocm() and rocm_aiter_ops.is_enabled() + else None, ) def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor: @@ -239,11 +244,11 @@ class LlamaAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) + # q, k = self.rotary_emb(positions, q, k) if self.do_llama_4_scaling: attn_scale = self._get_llama_4_attn_scale(positions) q = (q * attn_scale).to(q.dtype) - attn_output = self.attn(q, k, v) + attn_output = self.attn(q, k, v, positions=positions) output, _ = self.o_proj(attn_output) return output diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index b6aa0ae2be48e..e9cdc8b715d9b 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -42,6 +42,11 @@ if current_platform.is_rocm(): def num_programs(total_tokens): 
return min(total_tokens, get_cu_count()) + from vllm._aiter_ops import rocm_aiter_ops + + if rocm_aiter_ops.is_enabled(): + from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache + @triton.jit def cp_mha_gather_cache_kernel( key_cache_ptr, # [num_blocks, page_size, num_head, head_size] @@ -782,6 +787,7 @@ class AiterFlashAttentionImpl(AttentionImpl): output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with AiterFlashAttention. @@ -823,29 +829,72 @@ class AiterFlashAttentionImpl(AttentionImpl): # key and value may be None in the case of cross attention. They are # calculated once based on the output from the encoder and then cached # in KV cache. - if ( - self.kv_sharing_target_layer_name is None - and key is not None - and value is not None - ): - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. - # NOTE(woosuk): Here, key and value are padded while slot_mapping - # is not padded. However, we don't need to do - # key[:num_actual_tokens] and value[:num_actual_tokens] because - # the reshape_and_cache_flash op uses the slot_mapping's shape - # to determine the number of actual tokens. 
- torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + if ( + positions is not None + and query.shape[0] <= 256 + and rocm_aiter_ops.is_enabled() + ): + assert self.kv_sharing_target_layer_name is None + cos_sin_cache = self.rotary_emb.cos_sin_cache + is_neox = self.rotary_emb.is_neox_style + cos, sin = cos_sin_cache.chunk(2, dim=-1) + is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + if is_fp8_kv_cache: + key_cache = key_cache.view(current_platform.fp8_dtype()) + value_cache = value_cache.view(current_platform.fp8_dtype()) + query, key, key_cache, value_cache, output = ( + fused_qk_rope_reshape_and_cache( + query, + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + positions, + cos, + sin, + layer._k_scale, + layer._v_scale, + is_neox, + flash_layout=True, + apply_scale=is_fp8_kv_cache, + offs=None, + q_out=query, + k_out=key, + output_zeros=True, + zeros_out=output, + ) ) + else: + if positions is not None: + if current_platform.is_rocm(): + query, key = self.rotary_emb.forward_cuda(positions, query, key) + else: + query, key = self.rotary_emb(positions, query, key) + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping + # is not padded. However, we don't need to do + # key[:num_actual_tokens] and value[:num_actual_tokens] because + # the reshape_and_cache_flash op uses the slot_mapping's shape + # to determine the number of actual tokens. 
+ + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(current_platform.fp8_dtype()) diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index e2410a70b1a63..b2113daf925d5 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -31,6 +31,12 @@ from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) +if current_platform.is_rocm(): + from vllm._aiter_ops import rocm_aiter_ops + + if rocm_aiter_ops.is_enabled(): + from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache + @dataclass class RocmAttentionMetadata: @@ -264,6 +270,7 @@ class RocmAttentionImpl(AttentionImpl): output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -306,19 +313,61 @@ class RocmAttentionImpl(AttentionImpl): kv_cache, self.num_kv_heads, self.head_size ) - if self.kv_sharing_target_layer_name is None: - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. 
- PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + if ( + positions is not None + and query.shape[0] <= 256 + and rocm_aiter_ops.is_enabled() + ): + assert self.kv_sharing_target_layer_name is None + cos_sin_cache = self.rotary_emb.cos_sin_cache + is_neox = self.rotary_emb.is_neox_style + cos, sin = cos_sin_cache.chunk(2, dim=-1) + is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + if is_fp8_kv_cache: + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + query, key, key_cache, value_cache, output = ( + fused_qk_rope_reshape_and_cache( + query, + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + positions, + cos, + sin, + layer._k_scale, + layer._v_scale, + is_neox, + flash_layout=False, + apply_scale=is_fp8_kv_cache, + offs=None, + q_out=query, + k_out=key, + output_zeros=True, + zeros_out=output, + ) ) + else: + if positions is not None: + if current_platform.is_rocm(): + query, key = self.rotary_emb.forward_cuda(positions, query, key) + else: + query, key = self.rotary_emb(positions, query, key) + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. 
+ PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(self.fp8_dtype) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index ca7be990ca555..4abfc560fb4b9 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import ClassVar import torch +from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache from vllm.attention.backends.abstract import ( AttentionBackend, @@ -379,6 +380,7 @@ class TritonAttentionImpl(AttentionImpl): attn_metadata: TritonAttentionMetadata, output: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, + positions: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, ) -> torch.Tensor: """Forward pass with Paged Attention impl. in Triton. @@ -418,30 +420,75 @@ class TritonAttentionImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens key_cache, value_cache = kv_cache.unbind(1) + # positions is not None entails that Q and K are not RoPE embedded yet, + # therefore, either fused_qk_rope_reshape_and_cache or self.rotary_emb is called - if ( - self.kv_sharing_target_layer_name is None - and key is not None - and value is not None - ): - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. 
- if self.kv_cache_dtype.startswith("fp8"): + if positions is not None and query.shape[0] <= 256: + assert self.kv_sharing_target_layer_name is None, ( + "fused RoPE path does not support KV sharing; kv_sharing_target_layer_name must be None" + ) + assert hasattr(self, "rotary_emb"), f"rotary_emb not found in {self}" + cos_sin_cache = self.rotary_emb.cos_sin_cache + is_neox = self.rotary_emb.is_neox_style + cos, sin = cos_sin_cache.chunk(2, dim=-1) + is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") + if is_fp8_kv_cache: + key_cache_og_dtype = key_cache.dtype + value_cache_og_dtype = value_cache.dtype key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) - # triton kernel does not support uint8 kv_cache - # (because some explicit casts (e.g. float8_e4m3fnuz) - # are not supported) - triton_reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, + query, key, key_cache, value_cache, output = ( + fused_qk_rope_reshape_and_cache( + query, + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + positions, + cos, + sin, + layer._k_scale, + layer._v_scale, + is_neox, + flash_layout=True, + apply_scale=is_fp8_kv_cache, + offs=None, + q_out=query, + k_out=key, + output_zeros=True, + zeros_out=output, + ) ) + if is_fp8_kv_cache: + key_cache = key_cache.view(key_cache_og_dtype) + value_cache = value_cache.view(value_cache_og_dtype) + else: + if positions is not None: + query, key = self.rotary_emb(positions, query, key) + if ( + self.kv_sharing_target_layer_name is None + and key is not None + and value is not None + ): + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. 
+ if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + # triton kernel does not support uint8 kv_cache + # (because some explicit casts (e.g. float8_e4m3fnuz) + # are not supported) + triton_reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): if key_cache.dtype != self.fp8_dtype: