mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 23:27:06 +08:00
Merge ROCm:rope_kvcache_fusion
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
parent
cee749f6f4
commit
25b5317ce4
@ -142,6 +142,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: str | None = None,
|
||||
attn_backend: type[AttentionBackend] | None = None,
|
||||
rotary_emb: nn.Module | None = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
"""
|
||||
@ -207,7 +208,6 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
)
|
||||
else:
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
# prefix caching + batch invariance is currently not supported for
|
||||
# FLASHINFER and TRITON_MLA.
|
||||
if (
|
||||
@ -240,6 +240,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
kv_sharing_target_layer_name,
|
||||
**extra_impl_args,
|
||||
)
|
||||
self.impl.rotary_emb = rotary_emb
|
||||
backend_name = self.attn_backend.get_name()
|
||||
self.backend = AttentionBackendEnum.__members__.get(backend_name)
|
||||
self.dtype = dtype
|
||||
@ -296,6 +297,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
# shape does not match the query shape, so we optionally let the model
|
||||
# definition specify the output tensor shape.
|
||||
output_shape: torch.Size | None = None,
|
||||
positions: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
The KV cache is stored inside this class and is accessed via
|
||||
@ -345,7 +347,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
||||
)
|
||||
else:
|
||||
torch.ops.vllm.unified_attention_with_output(
|
||||
query, key, value, output, self.layer_name
|
||||
query, key, value, output, self.layer_name, positions=positions
|
||||
)
|
||||
return output.view(-1, hidden_size)
|
||||
else:
|
||||
@ -564,6 +566,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
prefix: str = "",
|
||||
use_sparse: bool = False,
|
||||
indexer: object | None = None,
|
||||
rotary_emb: nn.Module | None = None,
|
||||
**extra_impl_args,
|
||||
):
|
||||
super().__init__()
|
||||
@ -641,6 +644,8 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
**extra_impl_args,
|
||||
)
|
||||
|
||||
self.impl.rotary_emb = rotary_emb
|
||||
|
||||
self.use_direct_call = not current_platform.opaque_attention_op()
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
@ -668,6 +673,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
output_shape: torch.Size | None = None,
|
||||
positions: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if self.calculate_kv_scales:
|
||||
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
|
||||
@ -704,6 +710,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
||||
k_pe,
|
||||
output,
|
||||
self.layer_name,
|
||||
positions=positions,
|
||||
)
|
||||
return output
|
||||
else:
|
||||
@ -858,8 +865,24 @@ def unified_attention_with_output(
|
||||
layer_name: str,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
positions: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
attn_metadata, self, kv_cache = get_attention_context(layer_name)
|
||||
if positions is not None:
|
||||
assert hasattr(self.impl, "rotary_emb") and self.impl.rotary_emb is not None
|
||||
self.impl.forward(
|
||||
self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output,
|
||||
output_scale=output_scale,
|
||||
output_block_scale=output_block_scale,
|
||||
positions=positions,
|
||||
)
|
||||
return
|
||||
self.impl.forward(
|
||||
self,
|
||||
query,
|
||||
@ -881,6 +904,7 @@ def unified_attention_with_output_fake(
|
||||
layer_name: str,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
positions: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
@ -931,6 +955,7 @@ def unified_mla_attention_with_output(
|
||||
k_pe: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
positions: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
@ -954,6 +979,7 @@ def unified_mla_attention_with_output_fake(
|
||||
k_pe: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
positions: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
|
||||
@ -7,6 +7,7 @@ import torch.distributed as dist
|
||||
from torch import nn
|
||||
from transformers import GptOssConfig
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
@ -125,6 +126,9 @@ class OAIAttention(nn.Module):
|
||||
attn_type=AttentionType.DECODER,
|
||||
prefix=f"{prefix}.attn",
|
||||
sinks=self.sinks,
|
||||
rotary_emb=self.rotary_emb
|
||||
if current_platform.is_rocm() and rocm_aiter_ops.is_enabled()
|
||||
else None,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@ -132,9 +136,9 @@ class OAIAttention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
# q, k = self.rotary_emb(positions, q, k)
|
||||
v = v.contiguous()
|
||||
attn_output = self.attn(q, k, v)
|
||||
attn_output = self.attn(q, k, v, positions=positions)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
@ -31,6 +31,7 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||
@ -55,6 +56,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
@ -219,6 +221,9 @@ class LlamaAttention(nn.Module):
|
||||
per_layer_sliding_window=sliding_window,
|
||||
attn_type=attn_type,
|
||||
prefix=f"{prefix}.attn",
|
||||
rotary_emb=self.rotary_emb
|
||||
if current_platform.is_rocm() and rocm_aiter_ops.is_enabled()
|
||||
else None,
|
||||
)
|
||||
|
||||
def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
|
||||
@ -239,11 +244,11 @@ class LlamaAttention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
# q, k = self.rotary_emb(positions, q, k)
|
||||
if self.do_llama_4_scaling:
|
||||
attn_scale = self._get_llama_4_attn_scale(positions)
|
||||
q = (q * attn_scale).to(q.dtype)
|
||||
attn_output = self.attn(q, k, v)
|
||||
attn_output = self.attn(q, k, v, positions=positions)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
@ -42,6 +42,11 @@ if current_platform.is_rocm():
|
||||
def num_programs(total_tokens):
|
||||
return min(total_tokens, get_cu_count())
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
|
||||
|
||||
@triton.jit
|
||||
def cp_mha_gather_cache_kernel(
|
||||
key_cache_ptr, # [num_blocks, page_size, num_head, head_size]
|
||||
@ -782,6 +787,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
positions: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with AiterFlashAttention.
|
||||
|
||||
@ -823,29 +829,72 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
# in KV cache.
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping
|
||||
# is not padded. However, we don't need to do
|
||||
# key[:num_actual_tokens] and value[:num_actual_tokens] because
|
||||
# the reshape_and_cache_flash op uses the slot_mapping's shape
|
||||
# to determine the number of actual tokens.
|
||||
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
if (
|
||||
positions is not None
|
||||
and query.shape[0] <= 256
|
||||
and rocm_aiter_ops.is_enabled()
|
||||
):
|
||||
assert self.kv_sharing_target_layer_name is None
|
||||
cos_sin_cache = self.rotary_emb.cos_sin_cache
|
||||
is_neox = self.rotary_emb.is_neox_style
|
||||
cos, sin = cos_sin_cache.chunk(2, dim=-1)
|
||||
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
|
||||
if is_fp8_kv_cache:
|
||||
key_cache = key_cache.view(current_platform.fp8_dtype())
|
||||
value_cache = value_cache.view(current_platform.fp8_dtype())
|
||||
query, key, key_cache, value_cache, output = (
|
||||
fused_qk_rope_reshape_and_cache(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
positions,
|
||||
cos,
|
||||
sin,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
is_neox,
|
||||
flash_layout=True,
|
||||
apply_scale=is_fp8_kv_cache,
|
||||
offs=None,
|
||||
q_out=query,
|
||||
k_out=key,
|
||||
output_zeros=True,
|
||||
zeros_out=output,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if positions is not None:
|
||||
if current_platform.is_rocm():
|
||||
query, key = self.rotary_emb.forward_cuda(positions, query, key)
|
||||
else:
|
||||
query, key = self.rotary_emb(positions, query, key)
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping
|
||||
# is not padded. However, we don't need to do
|
||||
# key[:num_actual_tokens] and value[:num_actual_tokens] because
|
||||
# the reshape_and_cache_flash op uses the slot_mapping's shape
|
||||
# to determine the number of actual tokens.
|
||||
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(current_platform.fp8_dtype())
|
||||
|
||||
@ -31,6 +31,12 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
|
||||
|
||||
|
||||
@dataclass
|
||||
class RocmAttentionMetadata:
|
||||
@ -264,6 +270,7 @@ class RocmAttentionImpl(AttentionImpl):
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
positions: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention.
|
||||
|
||||
@ -306,19 +313,61 @@ class RocmAttentionImpl(AttentionImpl):
|
||||
kv_cache, self.num_kv_heads, self.head_size
|
||||
)
|
||||
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
if (
|
||||
positions is not None
|
||||
and query.shape[0] <= 256
|
||||
and rocm_aiter_ops.is_enabled()
|
||||
):
|
||||
assert self.kv_sharing_target_layer_name is None
|
||||
cos_sin_cache = self.rotary_emb.cos_sin_cache
|
||||
is_neox = self.rotary_emb.is_neox_style
|
||||
cos, sin = cos_sin_cache.chunk(2, dim=-1)
|
||||
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
|
||||
if is_fp8_kv_cache:
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
query, key, key_cache, value_cache, output = (
|
||||
fused_qk_rope_reshape_and_cache(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
positions,
|
||||
cos,
|
||||
sin,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
is_neox,
|
||||
flash_layout=False,
|
||||
apply_scale=is_fp8_kv_cache,
|
||||
offs=None,
|
||||
q_out=query,
|
||||
k_out=key,
|
||||
output_zeros=True,
|
||||
zeros_out=output,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if positions is not None:
|
||||
if current_platform.is_rocm():
|
||||
query, key = self.rotary_emb.forward_cuda(positions, query, key)
|
||||
else:
|
||||
query, key = self.rotary_emb(positions, query, key)
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
PagedAttention.write_to_paged_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
|
||||
@ -6,6 +6,7 @@ from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
@ -379,6 +380,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
attn_metadata: TritonAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
positions: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Paged Attention impl. in Triton.
|
||||
@ -418,30 +420,75 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
key_cache, value_cache = kv_cache.unbind(1)
|
||||
# positions is not None entails that Q and K are not RoPE embedded yet,
|
||||
# therefore, either fused_qk_rope_reshape_and_cache or self.rotary_emb is called
|
||||
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
if positions is not None and query.shape[0] <= 256:
|
||||
assert self.kv_sharing_target_layer_name is None, (
|
||||
"self.kv_sharing_target_layer_name cannot be None"
|
||||
)
|
||||
assert hasattr(self, "rotary_emb"), f"rotary_emb not found in {self}"
|
||||
cos_sin_cache = self.rotary_emb.cos_sin_cache
|
||||
is_neox = self.rotary_emb.is_neox_style
|
||||
cos, sin = cos_sin_cache.chunk(2, dim=-1)
|
||||
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
|
||||
if is_fp8_kv_cache:
|
||||
key_cache_og_dtype = key_cache.dtype
|
||||
value_cache_og_dtype = value_cache.dtype
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
# triton kernel does not support uint8 kv_cache
|
||||
# (because some explicit casts (e.g. float8_e4m3fnuz)
|
||||
# are not supported)
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
query, key, key_cache, value_cache, output = (
|
||||
fused_qk_rope_reshape_and_cache(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
positions,
|
||||
cos,
|
||||
sin,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
is_neox,
|
||||
flash_layout=True,
|
||||
apply_scale=is_fp8_kv_cache,
|
||||
offs=None,
|
||||
q_out=query,
|
||||
k_out=key,
|
||||
output_zeros=True,
|
||||
zeros_out=output,
|
||||
)
|
||||
)
|
||||
if is_fp8_kv_cache:
|
||||
key_cache = key_cache.view(key_cache_og_dtype)
|
||||
value_cache = value_cache.view(value_cache_og_dtype)
|
||||
else:
|
||||
if positions is not None:
|
||||
query, key = self.rotary_emb(positions, query, key)
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
and value is not None
|
||||
):
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(self.fp8_dtype)
|
||||
value_cache = value_cache.view(self.fp8_dtype)
|
||||
# triton kernel does not support uint8 kv_cache
|
||||
# (because some explicit casts (e.g. float8_e4m3fnuz)
|
||||
# are not supported)
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
if key_cache.dtype != self.fp8_dtype:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user