mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 18:35:42 +08:00
Merge ROCm:rope_kvcache_fusion
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
parent
cee749f6f4
commit
25b5317ce4
@ -142,6 +142,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
|||||||
attn_type: str = AttentionType.DECODER,
|
attn_type: str = AttentionType.DECODER,
|
||||||
kv_sharing_target_layer_name: str | None = None,
|
kv_sharing_target_layer_name: str | None = None,
|
||||||
attn_backend: type[AttentionBackend] | None = None,
|
attn_backend: type[AttentionBackend] | None = None,
|
||||||
|
rotary_emb: nn.Module | None = None,
|
||||||
**extra_impl_args,
|
**extra_impl_args,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@ -207,7 +208,6 @@ class Attention(nn.Module, AttentionLayerBase):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.attn_backend = attn_backend
|
self.attn_backend = attn_backend
|
||||||
|
|
||||||
# prefix caching + batch invariance is currently not supported for
|
# prefix caching + batch invariance is currently not supported for
|
||||||
# FLASHINFER and TRITON_MLA.
|
# FLASHINFER and TRITON_MLA.
|
||||||
if (
|
if (
|
||||||
@ -240,6 +240,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
|||||||
kv_sharing_target_layer_name,
|
kv_sharing_target_layer_name,
|
||||||
**extra_impl_args,
|
**extra_impl_args,
|
||||||
)
|
)
|
||||||
|
self.impl.rotary_emb = rotary_emb
|
||||||
backend_name = self.attn_backend.get_name()
|
backend_name = self.attn_backend.get_name()
|
||||||
self.backend = AttentionBackendEnum.__members__.get(backend_name)
|
self.backend = AttentionBackendEnum.__members__.get(backend_name)
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
@ -296,6 +297,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
|||||||
# shape does not match the query shape, so we optionally let the model
|
# shape does not match the query shape, so we optionally let the model
|
||||||
# definition specify the output tensor shape.
|
# definition specify the output tensor shape.
|
||||||
output_shape: torch.Size | None = None,
|
output_shape: torch.Size | None = None,
|
||||||
|
positions: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
The KV cache is stored inside this class and is accessed via
|
The KV cache is stored inside this class and is accessed via
|
||||||
@ -345,7 +347,7 @@ class Attention(nn.Module, AttentionLayerBase):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
torch.ops.vllm.unified_attention_with_output(
|
torch.ops.vllm.unified_attention_with_output(
|
||||||
query, key, value, output, self.layer_name
|
query, key, value, output, self.layer_name, positions=positions
|
||||||
)
|
)
|
||||||
return output.view(-1, hidden_size)
|
return output.view(-1, hidden_size)
|
||||||
else:
|
else:
|
||||||
@ -564,6 +566,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
|||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
use_sparse: bool = False,
|
use_sparse: bool = False,
|
||||||
indexer: object | None = None,
|
indexer: object | None = None,
|
||||||
|
rotary_emb: nn.Module | None = None,
|
||||||
**extra_impl_args,
|
**extra_impl_args,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -641,6 +644,8 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
|||||||
**extra_impl_args,
|
**extra_impl_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.impl.rotary_emb = rotary_emb
|
||||||
|
|
||||||
self.use_direct_call = not current_platform.opaque_attention_op()
|
self.use_direct_call = not current_platform.opaque_attention_op()
|
||||||
|
|
||||||
compilation_config = get_current_vllm_config().compilation_config
|
compilation_config = get_current_vllm_config().compilation_config
|
||||||
@ -668,6 +673,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
|||||||
kv_c_normed: torch.Tensor,
|
kv_c_normed: torch.Tensor,
|
||||||
k_pe: torch.Tensor,
|
k_pe: torch.Tensor,
|
||||||
output_shape: torch.Size | None = None,
|
output_shape: torch.Size | None = None,
|
||||||
|
positions: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.calculate_kv_scales:
|
if self.calculate_kv_scales:
|
||||||
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
|
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
|
||||||
@ -704,6 +710,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
|
|||||||
k_pe,
|
k_pe,
|
||||||
output,
|
output,
|
||||||
self.layer_name,
|
self.layer_name,
|
||||||
|
positions=positions,
|
||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
else:
|
else:
|
||||||
@ -858,8 +865,24 @@ def unified_attention_with_output(
|
|||||||
layer_name: str,
|
layer_name: str,
|
||||||
output_scale: torch.Tensor | None = None,
|
output_scale: torch.Tensor | None = None,
|
||||||
output_block_scale: torch.Tensor | None = None,
|
output_block_scale: torch.Tensor | None = None,
|
||||||
|
positions: torch.Tensor | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
attn_metadata, self, kv_cache = get_attention_context(layer_name)
|
attn_metadata, self, kv_cache = get_attention_context(layer_name)
|
||||||
|
if positions is not None:
|
||||||
|
assert hasattr(self.impl, "rotary_emb") and self.impl.rotary_emb is not None
|
||||||
|
self.impl.forward(
|
||||||
|
self,
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
kv_cache,
|
||||||
|
attn_metadata,
|
||||||
|
output=output,
|
||||||
|
output_scale=output_scale,
|
||||||
|
output_block_scale=output_block_scale,
|
||||||
|
positions=positions,
|
||||||
|
)
|
||||||
|
return
|
||||||
self.impl.forward(
|
self.impl.forward(
|
||||||
self,
|
self,
|
||||||
query,
|
query,
|
||||||
@ -881,6 +904,7 @@ def unified_attention_with_output_fake(
|
|||||||
layer_name: str,
|
layer_name: str,
|
||||||
output_scale: torch.Tensor | None = None,
|
output_scale: torch.Tensor | None = None,
|
||||||
output_block_scale: torch.Tensor | None = None,
|
output_block_scale: torch.Tensor | None = None,
|
||||||
|
positions: torch.Tensor | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -931,6 +955,7 @@ def unified_mla_attention_with_output(
|
|||||||
k_pe: torch.Tensor,
|
k_pe: torch.Tensor,
|
||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
|
positions: torch.Tensor | None = None,
|
||||||
output_scale: torch.Tensor | None = None,
|
output_scale: torch.Tensor | None = None,
|
||||||
output_block_scale: torch.Tensor | None = None,
|
output_block_scale: torch.Tensor | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -954,6 +979,7 @@ def unified_mla_attention_with_output_fake(
|
|||||||
k_pe: torch.Tensor,
|
k_pe: torch.Tensor,
|
||||||
output: torch.Tensor,
|
output: torch.Tensor,
|
||||||
layer_name: str,
|
layer_name: str,
|
||||||
|
positions: torch.Tensor | None = None,
|
||||||
output_scale: torch.Tensor | None = None,
|
output_scale: torch.Tensor | None = None,
|
||||||
output_block_scale: torch.Tensor | None = None,
|
output_block_scale: torch.Tensor | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import torch.distributed as dist
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import GptOssConfig
|
from transformers import GptOssConfig
|
||||||
|
|
||||||
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.attention.backends.abstract import AttentionType
|
from vllm.attention.backends.abstract import AttentionType
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.compilation.decorators import support_torch_compile
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
@ -125,6 +126,9 @@ class OAIAttention(nn.Module):
|
|||||||
attn_type=AttentionType.DECODER,
|
attn_type=AttentionType.DECODER,
|
||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
sinks=self.sinks,
|
sinks=self.sinks,
|
||||||
|
rotary_emb=self.rotary_emb
|
||||||
|
if current_platform.is_rocm() and rocm_aiter_ops.is_enabled()
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -132,9 +136,9 @@ class OAIAttention(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
q, k = self.rotary_emb(positions, q, k)
|
# q, k = self.rotary_emb(positions, q, k)  # disabled here: positions are passed to self.attn instead
|
||||||
v = v.contiguous()
|
v = v.contiguous()
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v, positions=positions)
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
@ -31,6 +31,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.attention.backends.abstract import AttentionType
|
from vllm.attention.backends.abstract import AttentionType
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||||
@ -55,6 +56,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
|||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
maybe_remap_kv_scale_name,
|
maybe_remap_kv_scale_name,
|
||||||
)
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
|
from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
|
||||||
@ -219,6 +221,9 @@ class LlamaAttention(nn.Module):
|
|||||||
per_layer_sliding_window=sliding_window,
|
per_layer_sliding_window=sliding_window,
|
||||||
attn_type=attn_type,
|
attn_type=attn_type,
|
||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
|
rotary_emb=self.rotary_emb
|
||||||
|
if current_platform.is_rocm() and rocm_aiter_ops.is_enabled()
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
|
def _get_llama_4_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
|
||||||
@ -239,11 +244,11 @@ class LlamaAttention(nn.Module):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
q, k = self.rotary_emb(positions, q, k)
|
# q, k = self.rotary_emb(positions, q, k)  # disabled here: positions are passed to self.attn instead
|
||||||
if self.do_llama_4_scaling:
|
if self.do_llama_4_scaling:
|
||||||
attn_scale = self._get_llama_4_attn_scale(positions)
|
attn_scale = self._get_llama_4_attn_scale(positions)
|
||||||
q = (q * attn_scale).to(q.dtype)
|
q = (q * attn_scale).to(q.dtype)
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v, positions=positions)
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
@ -42,6 +42,11 @@ if current_platform.is_rocm():
|
|||||||
def num_programs(total_tokens):
|
def num_programs(total_tokens):
|
||||||
return min(total_tokens, get_cu_count())
|
return min(total_tokens, get_cu_count())
|
||||||
|
|
||||||
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
|
|
||||||
|
if rocm_aiter_ops.is_enabled():
|
||||||
|
from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def cp_mha_gather_cache_kernel(
|
def cp_mha_gather_cache_kernel(
|
||||||
key_cache_ptr, # [num_blocks, page_size, num_head, head_size]
|
key_cache_ptr, # [num_blocks, page_size, num_head, head_size]
|
||||||
@ -782,6 +787,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
output: torch.Tensor | None = None,
|
output: torch.Tensor | None = None,
|
||||||
output_scale: torch.Tensor | None = None,
|
output_scale: torch.Tensor | None = None,
|
||||||
output_block_scale: torch.Tensor | None = None,
|
output_block_scale: torch.Tensor | None = None,
|
||||||
|
positions: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with AiterFlashAttention.
|
"""Forward pass with AiterFlashAttention.
|
||||||
|
|
||||||
@ -823,29 +829,72 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
# key and value may be None in the case of cross attention. They are
|
# key and value may be None in the case of cross attention. They are
|
||||||
# calculated once based on the output from the encoder and then cached
|
# calculated once based on the output from the encoder and then cached
|
||||||
# in KV cache.
|
# in KV cache.
|
||||||
if (
|
|
||||||
self.kv_sharing_target_layer_name is None
|
|
||||||
and key is not None
|
|
||||||
and value is not None
|
|
||||||
):
|
|
||||||
# Reshape the input keys and values and store them in the cache.
|
|
||||||
# Skip this if sharing KV cache with an earlier attention layer.
|
|
||||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping
|
|
||||||
# is not padded. However, we don't need to do
|
|
||||||
# key[:num_actual_tokens] and value[:num_actual_tokens] because
|
|
||||||
# the reshape_and_cache_flash op uses the slot_mapping's shape
|
|
||||||
# to determine the number of actual tokens.
|
|
||||||
|
|
||||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
if (
|
||||||
key,
|
positions is not None
|
||||||
value,
|
and query.shape[0] <= 256
|
||||||
key_cache,
|
and rocm_aiter_ops.is_enabled()
|
||||||
value_cache,
|
):
|
||||||
attn_metadata.slot_mapping,
|
assert self.kv_sharing_target_layer_name is None
|
||||||
self.kv_cache_dtype,
|
cos_sin_cache = self.rotary_emb.cos_sin_cache
|
||||||
layer._k_scale,
|
is_neox = self.rotary_emb.is_neox_style
|
||||||
layer._v_scale,
|
cos, sin = cos_sin_cache.chunk(2, dim=-1)
|
||||||
|
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
|
||||||
|
if is_fp8_kv_cache:
|
||||||
|
key_cache = key_cache.view(current_platform.fp8_dtype())
|
||||||
|
value_cache = value_cache.view(current_platform.fp8_dtype())
|
||||||
|
query, key, key_cache, value_cache, output = (
|
||||||
|
fused_qk_rope_reshape_and_cache(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
attn_metadata.slot_mapping,
|
||||||
|
positions,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
layer._k_scale,
|
||||||
|
layer._v_scale,
|
||||||
|
is_neox,
|
||||||
|
flash_layout=True,
|
||||||
|
apply_scale=is_fp8_kv_cache,
|
||||||
|
offs=None,
|
||||||
|
q_out=query,
|
||||||
|
k_out=key,
|
||||||
|
output_zeros=True,
|
||||||
|
zeros_out=output,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
if positions is not None:
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
query, key = self.rotary_emb.forward_cuda(positions, query, key)
|
||||||
|
else:
|
||||||
|
query, key = self.rotary_emb(positions, query, key)
|
||||||
|
if (
|
||||||
|
self.kv_sharing_target_layer_name is None
|
||||||
|
and key is not None
|
||||||
|
and value is not None
|
||||||
|
):
|
||||||
|
# Reshape the input keys and values and store them in the cache.
|
||||||
|
# Skip this if sharing KV cache with an earlier attention layer.
|
||||||
|
# NOTE(woosuk): Here, key and value are padded while slot_mapping
|
||||||
|
# is not padded. However, we don't need to do
|
||||||
|
# key[:num_actual_tokens] and value[:num_actual_tokens] because
|
||||||
|
# the reshape_and_cache_flash op uses the slot_mapping's shape
|
||||||
|
# to determine the number of actual tokens.
|
||||||
|
|
||||||
|
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
attn_metadata.slot_mapping,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
layer._k_scale,
|
||||||
|
layer._v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if self.kv_cache_dtype.startswith("fp8"):
|
||||||
key_cache = key_cache.view(current_platform.fp8_dtype())
|
key_cache = key_cache.view(current_platform.fp8_dtype())
|
||||||
|
|||||||
@ -31,6 +31,12 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
|
|
||||||
|
if rocm_aiter_ops.is_enabled():
|
||||||
|
from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RocmAttentionMetadata:
|
class RocmAttentionMetadata:
|
||||||
@ -264,6 +270,7 @@ class RocmAttentionImpl(AttentionImpl):
|
|||||||
output: torch.Tensor | None = None,
|
output: torch.Tensor | None = None,
|
||||||
output_scale: torch.Tensor | None = None,
|
output_scale: torch.Tensor | None = None,
|
||||||
output_block_scale: torch.Tensor | None = None,
|
output_block_scale: torch.Tensor | None = None,
|
||||||
|
positions: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention.
|
"""Forward pass with FlashAttention.
|
||||||
|
|
||||||
@ -306,19 +313,61 @@ class RocmAttentionImpl(AttentionImpl):
|
|||||||
kv_cache, self.num_kv_heads, self.head_size
|
kv_cache, self.num_kv_heads, self.head_size
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.kv_sharing_target_layer_name is None:
|
if (
|
||||||
# Reshape the input keys and values and store them in the cache.
|
positions is not None
|
||||||
# Skip this if sharing KV cache with an earlier attention layer.
|
and query.shape[0] <= 256
|
||||||
PagedAttention.write_to_paged_cache(
|
and rocm_aiter_ops.is_enabled()
|
||||||
key,
|
):
|
||||||
value,
|
assert self.kv_sharing_target_layer_name is None
|
||||||
key_cache,
|
cos_sin_cache = self.rotary_emb.cos_sin_cache
|
||||||
value_cache,
|
is_neox = self.rotary_emb.is_neox_style
|
||||||
attn_metadata.slot_mapping,
|
cos, sin = cos_sin_cache.chunk(2, dim=-1)
|
||||||
self.kv_cache_dtype,
|
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
|
||||||
layer._k_scale,
|
if is_fp8_kv_cache:
|
||||||
layer._v_scale,
|
key_cache = key_cache.view(self.fp8_dtype)
|
||||||
|
value_cache = value_cache.view(self.fp8_dtype)
|
||||||
|
query, key, key_cache, value_cache, output = (
|
||||||
|
fused_qk_rope_reshape_and_cache(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
attn_metadata.slot_mapping,
|
||||||
|
positions,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
layer._k_scale,
|
||||||
|
layer._v_scale,
|
||||||
|
is_neox,
|
||||||
|
flash_layout=False,
|
||||||
|
apply_scale=is_fp8_kv_cache,
|
||||||
|
offs=None,
|
||||||
|
q_out=query,
|
||||||
|
k_out=key,
|
||||||
|
output_zeros=True,
|
||||||
|
zeros_out=output,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
if positions is not None:
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
query, key = self.rotary_emb.forward_cuda(positions, query, key)
|
||||||
|
else:
|
||||||
|
query, key = self.rotary_emb(positions, query, key)
|
||||||
|
if self.kv_sharing_target_layer_name is None:
|
||||||
|
# Reshape the input keys and values and store them in the cache.
|
||||||
|
# Skip this if sharing KV cache with an earlier attention layer.
|
||||||
|
PagedAttention.write_to_paged_cache(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
attn_metadata.slot_mapping,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
layer._k_scale,
|
||||||
|
layer._v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if self.kv_cache_dtype.startswith("fp8"):
|
||||||
key_cache = key_cache.view(self.fp8_dtype)
|
key_cache = key_cache.view(self.fp8_dtype)
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from dataclasses import dataclass
|
|||||||
from typing import ClassVar
|
from typing import ClassVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (
|
from vllm.attention.backends.abstract import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
@ -379,6 +380,7 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
attn_metadata: TritonAttentionMetadata,
|
attn_metadata: TritonAttentionMetadata,
|
||||||
output: torch.Tensor | None = None,
|
output: torch.Tensor | None = None,
|
||||||
output_scale: torch.Tensor | None = None,
|
output_scale: torch.Tensor | None = None,
|
||||||
|
positions: torch.Tensor | None = None,
|
||||||
output_block_scale: torch.Tensor | None = None,
|
output_block_scale: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with Paged Attention impl. in Triton.
|
"""Forward pass with Paged Attention impl. in Triton.
|
||||||
@ -418,30 +420,75 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
|
|
||||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||||
key_cache, value_cache = kv_cache.unbind(1)
|
key_cache, value_cache = kv_cache.unbind(1)
|
||||||
|
# A non-None `positions` means Q and K have not had RoPE applied yet,
|
||||||
|
# so either fused_qk_rope_reshape_and_cache or self.rotary_emb is called below.
|
||||||
|
|
||||||
if (
|
if positions is not None and query.shape[0] <= 256:
|
||||||
self.kv_sharing_target_layer_name is None
|
assert self.kv_sharing_target_layer_name is None, (
|
||||||
and key is not None
|
"self.kv_sharing_target_layer_name cannot be None"
|
||||||
and value is not None
|
)
|
||||||
):
|
assert hasattr(self, "rotary_emb"), f"rotary_emb not found in {self}"
|
||||||
# Reshape the input keys and values and store them in the cache.
|
cos_sin_cache = self.rotary_emb.cos_sin_cache
|
||||||
# Skip this if sharing KV cache with an earlier attention layer.
|
is_neox = self.rotary_emb.is_neox_style
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
cos, sin = cos_sin_cache.chunk(2, dim=-1)
|
||||||
|
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
|
||||||
|
if is_fp8_kv_cache:
|
||||||
|
key_cache_og_dtype = key_cache.dtype
|
||||||
|
value_cache_og_dtype = value_cache.dtype
|
||||||
key_cache = key_cache.view(self.fp8_dtype)
|
key_cache = key_cache.view(self.fp8_dtype)
|
||||||
value_cache = value_cache.view(self.fp8_dtype)
|
value_cache = value_cache.view(self.fp8_dtype)
|
||||||
# triton kernel does not support uint8 kv_cache
|
query, key, key_cache, value_cache, output = (
|
||||||
# (because some explicit casts (e.g. float8_e4m3fnuz)
|
fused_qk_rope_reshape_and_cache(
|
||||||
# are not supported)
|
query,
|
||||||
triton_reshape_and_cache_flash(
|
key,
|
||||||
key,
|
value,
|
||||||
value,
|
key_cache,
|
||||||
key_cache,
|
value_cache,
|
||||||
value_cache,
|
attn_metadata.slot_mapping,
|
||||||
attn_metadata.slot_mapping,
|
positions,
|
||||||
self.kv_cache_dtype,
|
cos,
|
||||||
layer._k_scale,
|
sin,
|
||||||
layer._v_scale,
|
layer._k_scale,
|
||||||
|
layer._v_scale,
|
||||||
|
is_neox,
|
||||||
|
flash_layout=True,
|
||||||
|
apply_scale=is_fp8_kv_cache,
|
||||||
|
offs=None,
|
||||||
|
q_out=query,
|
||||||
|
k_out=key,
|
||||||
|
output_zeros=True,
|
||||||
|
zeros_out=output,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
if is_fp8_kv_cache:
|
||||||
|
key_cache = key_cache.view(key_cache_og_dtype)
|
||||||
|
value_cache = value_cache.view(value_cache_og_dtype)
|
||||||
|
else:
|
||||||
|
if positions is not None:
|
||||||
|
query, key = self.rotary_emb(positions, query, key)
|
||||||
|
if (
|
||||||
|
self.kv_sharing_target_layer_name is None
|
||||||
|
and key is not None
|
||||||
|
and value is not None
|
||||||
|
):
|
||||||
|
# Reshape the input keys and values and store them in the cache.
|
||||||
|
# Skip this if sharing KV cache with an earlier attention layer.
|
||||||
|
if self.kv_cache_dtype.startswith("fp8"):
|
||||||
|
key_cache = key_cache.view(self.fp8_dtype)
|
||||||
|
value_cache = value_cache.view(self.fp8_dtype)
|
||||||
|
# triton kernel does not support uint8 kv_cache
|
||||||
|
# (because some explicit casts (e.g. float8_e4m3fnuz)
|
||||||
|
# are not supported)
|
||||||
|
triton_reshape_and_cache_flash(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
attn_metadata.slot_mapping,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
layer._k_scale,
|
||||||
|
layer._v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
if self.kv_cache_dtype.startswith("fp8"):
|
||||||
if key_cache.dtype != self.fp8_dtype:
|
if key_cache.dtype != self.fp8_dtype:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user