Merge ROCm:rope_kvcache_fusion

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Gregory Shtrasberg 2025-12-18 16:16:19 +00:00
parent cee749f6f4
commit 25b5317ce4
6 changed files with 239 additions and 59 deletions


@@ -142,6 +142,7 @@ class Attention(nn.Module, AttentionLayerBase):
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
attn_backend: type[AttentionBackend] | None = None,
rotary_emb: nn.Module | None = None,
**extra_impl_args,
) -> None:
"""
@@ -207,7 +208,6 @@
)
else:
self.attn_backend = attn_backend
# prefix caching + batch invariance is currently not supported for
# FLASHINFER and TRITON_MLA.
if (
@@ -240,6 +240,7 @@ class Attention(nn.Module, AttentionLayerBase):
kv_sharing_target_layer_name,
**extra_impl_args,
)
self.impl.rotary_emb = rotary_emb
backend_name = self.attn_backend.get_name()
self.backend = AttentionBackendEnum.__members__.get(backend_name)
self.dtype = dtype
@@ -296,6 +297,7 @@ class Attention(nn.Module, AttentionLayerBase):
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
output_shape: torch.Size | None = None,
positions: torch.Tensor | None = None,
) -> torch.Tensor:
"""
The KV cache is stored inside this class and is accessed via
@@ -345,7 +347,7 @@
)
else:
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name
query, key, value, output, self.layer_name, positions=positions
)
return output.view(-1, hidden_size)
else:
@@ -564,6 +566,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
prefix: str = "",
use_sparse: bool = False,
indexer: object | None = None,
rotary_emb: nn.Module | None = None,
**extra_impl_args,
):
super().__init__()
@@ -641,6 +644,8 @@ class MLAAttention(nn.Module, AttentionLayerBase):
**extra_impl_args,
)
self.impl.rotary_emb = rotary_emb
self.use_direct_call = not current_platform.opaque_attention_op()
compilation_config = get_current_vllm_config().compilation_config
@@ -668,6 +673,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
output_shape: torch.Size | None = None,
positions: torch.Tensor | None = None,
) -> torch.Tensor:
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
@@ -704,6 +710,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
k_pe,
output,
self.layer_name,
positions=positions,
)
return output
else:
@@ -858,8 +865,24 @@ def unified_attention_with_output(
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
) -> None:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
if positions is not None:
assert hasattr(self.impl, "rotary_emb") and self.impl.rotary_emb is not None
self.impl.forward(
self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
positions=positions,
)
return
self.impl.forward(
self,
query,
@@ -881,6 +904,7 @@ def unified_attention_with_output_fake(
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
) -> None:
return
@@ -931,6 +955,7 @@ def unified_mla_attention_with_output(
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
positions: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
@@ -954,6 +979,7 @@ def unified_mla_attention_with_output_fake(
k_pe: torch.Tensor,
output: torch.Tensor,
layer_name: str,
positions: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> None:
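Both the unified attention custom ops and their fake ("meta") counterparts gain the optional positions argument in this commit; the fake implementations have to keep their signatures in sync with the real ops so that compiled graphs which pass positions still trace. A minimal, self-contained sketch of that pattern, using torch.library.custom_op directly rather than vLLM's own registration helper (the op name and body are illustrative only):

import torch
from typing import Optional

# Hypothetical demo op; vLLM registers its ops through its own helper instead.
@torch.library.custom_op("demo::attn_with_output", mutates_args=("output",))
def attn_with_output(
    query: torch.Tensor,
    output: torch.Tensor,
    positions: Optional[torch.Tensor] = None,
) -> None:
    # Stand-in for the real kernel: a backend that owns a rotary embedding
    # could apply RoPE here when positions is provided, then run attention.
    output.copy_(query)

# The fake impl does no work but mirrors the signature (including positions),
# which is what tracing needs; mutates_args already marks `output` as written.
@attn_with_output.register_fake
def _(query, output, positions=None) -> None:
    return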


@@ -7,6 +7,7 @@ import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
@@ -125,6 +126,9 @@ class OAIAttention(nn.Module):
attn_type=AttentionType.DECODER,
prefix=f"{prefix}.attn",
sinks=self.sinks,
rotary_emb=self.rotary_emb
if current_platform.is_rocm() and rocm_aiter_ops.is_enabled()
else None,
)
def forward(
@@ -132,9 +136,9 @@
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
# q, k = self.rotary_emb(positions, q, k)
v = v.contiguous()
attn_output = self.attn(q, k, v)
attn_output = self.attn(q, k, v, positions=positions)
output, _ = self.o_proj(attn_output)
return output


@@ -31,6 +31,7 @@ import torch
from torch import nn
from transformers import LlamaConfig
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
@@ -55,6 +56,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
@@ -219,6 +221,9 @@ class LlamaAttention(nn.Module):
per_layer_sliding_window=sliding_window,
attn_type=attn_type,
prefix=f"{prefix}.attn",
rotary_emb=self.rotary_emb
if current_platform.is_rocm() and rocm_aiter_ops.is_enabled()
else None,
)
def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
@@ -239,11 +244,11 @@
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
# q, k = self.rotary_emb(positions, q, k)
if self.do_llama_4_scaling:
attn_scale = self._get_llama_4_attn_scale(positions)
q = (q * attn_scale).to(q.dtype)
attn_output = self.attn(q, k, v)
attn_output = self.attn(q, k, v, positions=positions)
output, _ = self.o_proj(attn_output)
return output


@@ -42,6 +42,11 @@ if current_platform.is_rocm():
def num_programs(total_tokens):
return min(total_tokens, get_cu_count())
from vllm._aiter_ops import rocm_aiter_ops
if rocm_aiter_ops.is_enabled():
from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
@triton.jit
def cp_mha_gather_cache_kernel(
key_cache_ptr, # [num_blocks, page_size, num_head, head_size]
@@ -782,6 +787,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with AiterFlashAttention.
@@ -823,29 +829,72 @@
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping
# is not padded. However, we don't need to do
# key[:num_actual_tokens] and value[:num_actual_tokens] because
# the reshape_and_cache_flash op uses the slot_mapping's shape
# to determine the number of actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
if (
positions is not None
and query.shape[0] <= 256
and rocm_aiter_ops.is_enabled()
):
assert self.kv_sharing_target_layer_name is None
cos_sin_cache = self.rotary_emb.cos_sin_cache
is_neox = self.rotary_emb.is_neox_style
cos, sin = cos_sin_cache.chunk(2, dim=-1)
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
if is_fp8_kv_cache:
key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype())
query, key, key_cache, value_cache, output = (
fused_qk_rope_reshape_and_cache(
query,
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
positions,
cos,
sin,
layer._k_scale,
layer._v_scale,
is_neox,
flash_layout=True,
apply_scale=is_fp8_kv_cache,
offs=None,
q_out=query,
k_out=key,
output_zeros=True,
zeros_out=output,
)
)
else:
if positions is not None:
if current_platform.is_rocm():
query, key = self.rotary_emb.forward_cuda(positions, query, key)
else:
query, key = self.rotary_emb(positions, query, key)
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping
# is not padded. However, we don't need to do
# key[:num_actual_tokens] and value[:num_actual_tokens] because
# the reshape_and_cache_flash op uses the slot_mapping's shape
# to determine the number of actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(current_platform.fp8_dtype())
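The fused path above splits the rotary embedding's cached table into cos and sin and applies RoPE to Q/K inside the same kernel that scatters K/V into the paged cache; the else branch keeps the unfused order (rotary_emb first, then reshape_and_cache). A rough reference of the unfused neox-style math, assuming vLLM's RotaryEmbedding stores cos_sin_cache with cos and sin concatenated along the last dimension (shapes are illustrative; the cache write is omitted):

import torch

def ref_neox_rope(x, cos, sin):
    # neox style: rotate the two halves of the last dimension
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

num_tokens, num_heads, head_dim = 8, 4, 64
q = torch.randn(num_tokens, num_heads, head_dim)
k = torch.randn(num_tokens, num_heads, head_dim)
positions = torch.arange(num_tokens)

# Assumed layout: [max_position, rotary_dim], cos in the first half of the
# last dim and sin in the second half, matching the chunk(2, dim=-1) above.
cos_sin_cache = torch.randn(4096, head_dim)
cos, sin = cos_sin_cache.chunk(2, dim=-1)
cos_t = cos[positions].unsqueeze(1)   # [num_tokens, 1, head_dim // 2]
sin_t = sin[positions].unsqueeze(1)

q_rot = ref_neox_rope(q, cos_t, sin_t)
k_rot = ref_neox_rope(k, cos_t, sin_t)
# fused_qk_rope_reshape_and_cache additionally writes K/V into key_cache /
# value_cache at attn_metadata.slot_mapping, which this sketch leaves out.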


@@ -31,6 +31,12 @@ from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
if current_platform.is_rocm():
from vllm._aiter_ops import rocm_aiter_ops
if rocm_aiter_ops.is_enabled():
from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
@dataclass
class RocmAttentionMetadata:
@@ -264,6 +270,7 @@ class RocmAttentionImpl(AttentionImpl):
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@@ -306,19 +313,61 @@
kv_cache, self.num_kv_heads, self.head_size
)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
if (
positions is not None
and query.shape[0] <= 256
and rocm_aiter_ops.is_enabled()
):
assert self.kv_sharing_target_layer_name is None
cos_sin_cache = self.rotary_emb.cos_sin_cache
is_neox = self.rotary_emb.is_neox_style
cos, sin = cos_sin_cache.chunk(2, dim=-1)
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
query, key, key_cache, value_cache, output = (
fused_qk_rope_reshape_and_cache(
query,
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
positions,
cos,
sin,
layer._k_scale,
layer._v_scale,
is_neox,
flash_layout=False,
apply_scale=is_fp8_kv_cache,
offs=None,
q_out=query,
k_out=key,
output_zeros=True,
zeros_out=output,
)
)
else:
if positions is not None:
if current_platform.is_rocm():
query, key = self.rotary_emb.forward_cuda(positions, query, key)
else:
query, key = self.rotary_emb(positions, query, key)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
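In both ROCm backends the fused call receives the KV-cache tensors viewed as the platform fp8 dtype when kv_cache_dtype is fp8, and the Triton backend below views them back to their original dtype afterwards. This is cheap because the fp8 cache buffer is uint8-typed at allocation and Tensor.view(dtype) between one-byte dtypes is a zero-copy reinterpretation; a tiny sketch, with torch.float8_e4m3fn standing in for the ROCm fp8 dtype:

import torch

buf = torch.zeros(4, 16, dtype=torch.uint8)   # cache buffer as allocated
fp8 = buf.view(torch.float8_e4m3fn)           # same storage, fp8 element type
assert fp8.data_ptr() == buf.data_ptr()       # no copy, just a dtype view
restored = fp8.view(torch.uint8)              # view back to the original dtype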


@@ -6,6 +6,7 @@ from dataclasses import dataclass
from typing import ClassVar
import torch
from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
from vllm.attention.backends.abstract import (
AttentionBackend,
@@ -379,6 +380,7 @@ class TritonAttentionImpl(AttentionImpl):
attn_metadata: TritonAttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with Paged Attention impl. in Triton.
@@ -418,30 +420,75 @@
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(1)
# A non-None positions means RoPE has not been applied to Q and K yet, so
# either fused_qk_rope_reshape_and_cache or self.rotary_emb is applied below
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
if self.kv_cache_dtype.startswith("fp8"):
if positions is not None and query.shape[0] <= 256:
assert self.kv_sharing_target_layer_name is None, (
"self.kv_sharing_target_layer_name must be None"
)
assert hasattr(self, "rotary_emb"), f"rotary_emb not found in {self}"
cos_sin_cache = self.rotary_emb.cos_sin_cache
is_neox = self.rotary_emb.is_neox_style
cos, sin = cos_sin_cache.chunk(2, dim=-1)
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
if is_fp8_kv_cache:
key_cache_og_dtype = key_cache.dtype
value_cache_og_dtype = value_cache.dtype
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
# triton kernel does not support uint8 kv_cache
# (because some explicit casts (e.g. float8_e4m3fnuz)
# are not supported)
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
query, key, key_cache, value_cache, output = (
fused_qk_rope_reshape_and_cache(
query,
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
positions,
cos,
sin,
layer._k_scale,
layer._v_scale,
is_neox,
flash_layout=True,
apply_scale=is_fp8_kv_cache,
offs=None,
q_out=query,
k_out=key,
output_zeros=True,
zeros_out=output,
)
)
if is_fp8_kv_cache:
key_cache = key_cache.view(key_cache_og_dtype)
value_cache = value_cache.view(value_cache_og_dtype)
else:
if positions is not None:
query, key = self.rotary_emb(positions, query, key)
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
# triton kernel does not support uint8 kv_cache
# (because some explicit casts (e.g. float8_e4m3fnuz)
# are not supported)
triton_reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
if key_cache.dtype != self.fp8_dtype: