[ROCm] Split AITER unified attention into its own backend (#25507)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Gregory Shtrasberg 2025-10-06 18:49:23 -04:00 committed by GitHub
parent 2161efe978
commit f231e5bc21
8 changed files with 325 additions and 301 deletions

View File

@@ -7,9 +7,7 @@ import pytest
import torch._dynamo
from tests.compile.backend import LazyInitPass, TestBackend
from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm import LLM, SamplingParams
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention, AttentionMetadata
from vllm.attention.backends.registry import _Backend
@@ -31,7 +29,6 @@ from vllm.config import (
)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
kNvfp4Quant,
)
@@ -48,132 +45,6 @@ backend: Optional[TestBackend] = None
backend_unfused: Optional[TestBackend] = None
@pytest.mark.parametrize(
"model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]
)
@pytest.mark.parametrize("use_triton_fa", [True, False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(
not current_platform.is_rocm(), reason="V0 attn quant fusion only on ROCm"
)
def test_attention_fusion_v0(
example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool
):
# Clean Dynamo cache to avoid reusing other test cases
# (for some reason the reset at the end is not enough)
torch._dynamo.reset()
# Use global backends
global backend, backend_unfused
monkeypatch.setenv("VLLM_USE_V1", "1")
monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa)))
# Prompt 4 seems too open-ended, differs between fused and unfused
# (both outputs look reasonable though)
prompts = example_prompts[:4] + example_prompts[5:]
compile_config = CompilationConfig(
# DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
# DYNAMO_ONCE does not properly propagate shapes.
level=CompilationLevel.DYNAMO_AS_IS,
backend="tests.compile.test_fusion_attn.backend_unfused",
custom_ops=["+quant_fp8"],
)
vllm_config = VllmConfig(
compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
),
)
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
llm = LLM(
model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.5,
max_model_len=2048,
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95)
unfused_output = llm.generate(prompts, sampling_params)
backend_unfused = None # Reset backend to make sure llm gets released
del llm
compile_config = CompilationConfig(
# DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation
# DYNAMO_ONCE does not properly propagate shapes.
level=CompilationLevel.DYNAMO_AS_IS,
backend="tests.compile.test_fusion_attn.backend",
custom_ops=["+quant_fp8"],
)
vllm_config = VllmConfig(
compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
),
)
# AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation.
attn_pass = LazyInitPass(AttnFusionPass, vllm_config)
backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass)
llm2 = LLM(
model,
enforce_eager=True,
compilation_config=compile_config,
gpu_memory_utilization=0.5,
max_model_len=2048,
)
# check support
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key)
for key, layer in compile_config.static_forward_context.items()
]
print(f"{attn_fusion_supported=}")
if any(attn_fusion_supported):
# Check quant ops
backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)
# attention ops present in both, just output_scale param changes
attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass))
attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass))
assert len(attn_nodes_pre) == len(attn_nodes_post)
for i in range(len(attn_nodes_pre)):
assert attn_nodes_pre[i].kwargs["output_scale"] is None
fused = attn_nodes_post[i].kwargs["output_scale"] is not None
assert fused == attn_fusion_supported[i], (
f"Node {i} {'' if fused else 'not '} expected to have fused output quant"
)
# check outputs
fused_output = llm2.generate(prompts, sampling_params)
# transform outputs to format expected by check_outputs_equal
sample_outs = lambda s: (list(s.token_ids), s.text)
outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros]
check_outputs_equal(
outputs_0_lst=outs_lst(unfused_output),
outputs_1_lst=outs_lst(fused_output),
name_0="unfused",
name_1="fused",
)
# Clean Dynamo cache to avoid polluting other case(s)
torch._dynamo.reset()
# Reset backend to make sure llm2 gets released
backend = None
class AttentionQuantPatternModel(torch.nn.Module):
"""Base model for AttentionQuantPattern fusion."""
@@ -221,7 +92,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
device=self.device,
)
def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata:
def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
"""Initialize attention metadata."""
# Create common attn metadata
@@ -232,30 +103,57 @@ class AttentionQuantPatternModel(torch.nn.Module):
max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size
num_blocks = batch_size * max_blocks
# Create dummy KV cache for FlashInfer TRTLLM
# - NHD: [num_blocks, block_size, num_kv_heads, head_size]
# - HND: [num_blocks, num_kv_heads, block_size, head_size]
kv_cache = torch.zeros(
num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
if current_platform.is_rocm():
# k/v as 1st dimention
if use_hnd:
kv_cache = kv_cache.permute(1, 0, 2, 3, 4)
else:
kv_cache = kv_cache.permute(1, 0, 3, 2, 4)
else:
# k/v as 2nd dimention
# Create kv_cache in HND layout and permute to NHD layout
# (later will be permuted back to HND layout in forward pass)
kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
backend = self.attn.backend
# Create dummy KV cache for the selected backend
if backend == _Backend.ROCM_ATTN:
# k/v as 1st dimention
# HND: [num_blocks, num_kv_heads, block_size, head_size]
kv_cache = torch.zeros(
2,
num_blocks,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
# k/v as 1st dimention
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros(
2,
num_blocks,
self.block_size,
self.num_kv_heads,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == _Backend.TRITON_ATTN:
# k/v as 2nd dimention
# NHD: [num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.zeros(
num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
)
elif backend == _Backend.FLASHINFER:
kv_cache = torch.zeros(
num_blocks,
2,
self.num_kv_heads,
self.block_size,
self.head_size,
dtype=self.kv_cache_dtype,
device=self.device,
).permute(0, 1, 3, 2, 4)
else:
raise ValueError(f"Unsupported backend: {backend}")
self.attn.kv_cache = [kv_cache]
# Build attn metadata
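For quick orientation, the per-backend layouts above can be summarized in a small sketch (illustration only; the helper name and string keys are hypothetical, derived from the build_attn_metadata branches in this hunk):

def expected_kv_cache_shape(
    backend: str, num_blocks: int, block_size: int, num_kv_heads: int, head_size: int
) -> tuple[int, ...]:
    # Mirrors the branches above (sketch, not code from the commit).
    if backend == "ROCM_ATTN":  # k/v stacked first, HND per block
        return (2, num_blocks, num_kv_heads, block_size, head_size)
    if backend == "ROCM_AITER_UNIFIED_ATTN":  # k/v stacked first, NHD per block
        return (2, num_blocks, block_size, num_kv_heads, head_size)
    if backend == "TRITON_ATTN":  # k/v stacked second
        return (num_blocks, 2, num_kv_heads, block_size, head_size)
    if backend == "FLASHINFER":
        # Allocated as (num_blocks, 2, num_kv_heads, block_size, head_size)
        # and then viewed through .permute(0, 1, 3, 2, 4).
        return (num_blocks, 2, block_size, num_kv_heads, head_size)
    raise ValueError(f"Unsupported backend: {backend}")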
@@ -375,10 +273,9 @@ else:
@pytest.mark.parametrize("model_name, model_class", MODELS)
@pytest.mark.parametrize(
"backend",
[_Backend.FLASHINFER] if current_platform.is_cuda() else [_Backend.TRITON_ATTN],
)
@pytest.mark.parametrize(
"split_attention", [False, True] if current_platform.is_rocm() else [False]
[_Backend.FLASHINFER]
if current_platform.is_cuda()
else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN],
)
# TODO(boyuan): test inductor graph partition on rocm
@pytest.mark.parametrize(
@@ -405,7 +302,6 @@ def test_attention_quant_pattern(
model_name: str,
model_class: type[AttentionQuantPatternModel],
backend: _Backend,
split_attention: bool,
use_inductor_graph_partition: bool,
monkeypatch,
dist_init,
@@ -417,8 +313,6 @@ def test_attention_quant_pattern(
pytest.skip("inductor graph partition is only available in PyTorch 2.9+")
monkeypatch.setenv("VLLM_USE_V1", "1")
if split_attention:
monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")
device = torch.device("cuda:0")
torch.manual_seed(42)
@@ -466,9 +360,7 @@ def test_attention_quant_pattern(
model_unfused = model_unfused.to(device)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
batch_size, use_hnd=split_attention
)
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
# Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v)
@@ -494,9 +386,7 @@ def test_attention_quant_pattern(
model_fused = model_fused.to(device)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_fused.build_attn_metadata(
batch_size, use_hnd=split_attention
)
forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)
# Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config)

View File

@@ -25,3 +25,4 @@ class _Backend(enum.Enum):
FLEX_ATTENTION = enum.auto()
TREE_ATTN = enum.auto()
ROCM_ATTN = enum.auto()
ROCM_AITER_UNIFIED_ATTN = enum.auto()

View File

@@ -254,3 +254,4 @@ def global_force_attn_backend_context_manager(
finally:
# Revert the original global backend override, if any
global_force_attn_backend(original_value)
_cached_get_attn_backend.cache_clear()
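For illustration, a minimal usage sketch of the new enum member together with the override context manager (assuming global_force_attn_backend_context_manager is importable from vllm.attention.selector; the registry import matches the one used elsewhere in this diff). The added cache_clear() ensures the next backend lookup is not served from a resolution cached before the override:

# Sketch: force the new backend for a scoped block (import paths assumed).
from vllm.attention.backends.registry import _Backend
from vllm.attention.selector import global_force_attn_backend_context_manager

with global_force_attn_backend_context_manager(_Backend.ROCM_AITER_UNIFIED_ATTN):
    # Backends resolved inside this block use the override; on exit the
    # previous value is restored and the resolution cache is cleared.
    ...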

View File

@@ -1623,6 +1623,7 @@ class EngineArgs:
"TREE_ATTN",
"XFORMERS",
"ROCM_ATTN",
"ROCM_AITER_UNIFIED_ATTN",
]
if (
envs.is_set("VLLM_ATTENTION_BACKEND")

View File

@@ -18,7 +18,6 @@ if TYPE_CHECKING:
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = True
VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
VLLM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_FLASH_ATTN_VERSION: Optional[int] = None
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -109,6 +108,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
VLLM_ROCM_USE_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
VLLM_ROCM_MOE_PADDING: bool = True
@@ -475,10 +475,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower()
in ("true", "1")
),
# Use AITER triton unified attention for V1 attention
"VLLM_USE_AITER_UNIFIED_ATTENTION": lambda: (
os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in ("true", "1")
),
# Force vllm to use a specific flash-attention version (2 or 3), only valid
# when using the flash-attention backend.
"VLLM_FLASH_ATTN_VERSION": lambda: maybe_convert_int(
@@ -896,6 +892,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ROCM_USE_AITER_FP8BMM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1")
),
# Use AITER triton unified attention for V1 attention
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower()
in ("true", "1")
),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM": lambda: (
os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1")
@@ -1434,7 +1435,6 @@ def compute_hash() -> str:
"VLLM_FUSED_MOE_CHUNK_SIZE",
"VLLM_FLASHINFER_MOE_BACKEND",
"VLLM_V1_USE_PREFILL_DECODE_ATTENTION",
"VLLM_USE_AITER_UNIFIED_ATTENTION",
"VLLM_ATTENTION_BACKEND",
"VLLM_USE_FLASHINFER_SAMPLER",
"VLLM_DISABLED_KERNELS",
@@ -1462,6 +1462,7 @@ def compute_hash() -> str:
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
"VLLM_ROCM_USE_TRITON_ROPE",
"VLLM_ROCM_USE_AITER_FP8BMM",
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION",
"VLLM_ROCM_USE_SKINNY_GEMM",
"VLLM_ROCM_FP8_PADDING",
"VLLM_ROCM_MOE_PADDING",

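To illustrate how the renamed flag is consumed (a hedged sketch; the variables must be set before vLLM reads them, e.g. in the launching environment), the AITER unified path can be opted into either through the ROCm-scoped flags or by naming the backend directly:

import os

# Option 1: enable AITER plus its unified attention implementation on ROCm.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION"] = "1"

# Option 2: request the backend by name; ROCM_AITER_UNIFIED_ATTN is now an
# accepted value for VLLM_ATTENTION_BACKEND (see the EngineArgs change above).
os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_UNIFIED_ATTN"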
View File

@@ -276,25 +276,33 @@ class RocmPlatform(Platform):
)
if envs.VLLM_USE_V1:
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
logger.info("Using Flash Attention backend on V1 engine.")
if (
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9()
) or selected_backend == _Backend.ROCM_AITER_FA:
logger.info("Using Aiter Flash Attention backend on V1 engine.")
return (
"vllm.v1.attention.backends."
"rocm_aiter_fa.AiterFlashAttentionBackend"
)
elif (
(envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION)
or envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
if (
envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
logger.info("Using Aiter Unified Attention backend on V1 engine.")
return (
"vllm.v1.attention.backends."
"rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
)
if (
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
or selected_backend == _Backend.ROCM_ATTN
):
# rocm specific backend, with aiter and/or
# triton prefix-prefill
logger.info("Using Rocm/Aiter Attention backend on V1 engine.")
logger.info("Using Rocm Attention backend on V1 engine.")
return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
else:
# default case, using triton unified attention
logger.info("Using Triton Attention backend on V1 engine.")
return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
raise RuntimeError(
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
"to select a supported backend."

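The resulting V1 selection priority can be read off the branches above; the helper below is only an illustration of that order (its name and boolean parameters are hypothetical, the real logic is the RocmPlatform code in this hunk):

def pick_rocm_v1_backend(
    selected_backend: str | None,
    use_aiter: bool,
    use_aiter_mha: bool,
    use_aiter_unified: bool,
    use_prefill_decode: bool,
    is_gfx9: bool,
) -> str:
    # 1. AITER flash attention (gfx9 + AITER MHA, or explicitly selected).
    if (use_aiter and use_aiter_mha and is_gfx9) or selected_backend == "ROCM_AITER_FA":
        return "rocm_aiter_fa.AiterFlashAttentionBackend"
    # 2. AITER unified attention, now a standalone backend.
    if (use_aiter and use_aiter_unified) or selected_backend == "ROCM_AITER_UNIFIED_ATTN":
        return "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
    # 3. ROCm prefill/decode attention.
    if use_prefill_decode or selected_backend == "ROCM_ATTN":
        return "rocm_attn.RocmAttentionBackend"
    # 4. Default: Triton unified attention.
    return "triton_attn.TritonAttentionBackend"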
View File

@@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
from typing import Optional
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import AttentionMetadata, AttentionType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8StaticTensorSym,
)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_attn import (
RocmAttentionBackend,
RocmAttentionImpl,
RocmAttentionMetadata,
RocmAttentionMetadataBuilder,
)
logger = init_logger(__name__)
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "ROCM_AITER_UNIFIED_ATTN"
@staticmethod
def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]:
return RocmAiterUnifiedAttentionImpl
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
return RocmAttentionMetadata
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> tuple[int, ...]:
if block_size % 16 != 0:
raise ValueError("Block size must be a multiple of 16.")
return (2, num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@staticmethod
def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
return RocmAttentionMetadataBuilder
class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
sinks: Optional[torch.Tensor] = None,
) -> None:
super().__init__(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
sinks,
)
logger.info_once(
"Using aiter unified attention for RocmAiterUnifiedAttentionImpl"
)
from aiter.ops.triton.unified_attention import unified_attention
self.unified_attention = unified_attention
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
kv_cache: shape =
[2, num_blocks, block_size, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert output is not None, "Output tensor must be provided."
if output_block_scale is not None:
raise NotImplementedError(
"fused block_scale output quantization is not yet supported"
" for RocmAttentionImpl"
)
if attn_metadata is None:
# Profiling run.
return output
assert attn_metadata.use_cascade is False
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# are surprisingly slow even in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
self.unified_attention(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
sinks=self.sinks,
output_scale=output_scale,
)
return output
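A short sketch of how the declared cache layout lines up with the forward pass above (the module path matches the one registered in the ROCm platform code in this PR; the concrete sizes are arbitrary):

from vllm.v1.attention.backends.rocm_aiter_unified_attn import (
    RocmAiterUnifiedAttentionBackend,
)

# get_kv_cache_shape stacks K and V along dim 0, which is exactly what
# forward() later splits apart with kv_cache.unbind(0).
shape = RocmAiterUnifiedAttentionBackend.get_kv_cache_shape(
    num_blocks=128, block_size=16, num_kv_heads=8, head_size=128
)
assert shape == (2, 128, 16, 8, 128)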

View File

@@ -3,13 +3,10 @@
"""Attention layer with PagedAttention and Triton prefix prefill."""
from dataclasses import dataclass
from functools import cache
from typing import ClassVar, Optional
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
@@ -96,12 +93,11 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat
# slow, so here we set it to 1.
attn_metadata.seq_lens.fill_(1)
if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
# Here we set the query start locs to 0. This is to
# cover up an invalid memory access in the prefix_prefil kernel
# that we run into during graph capture (#25985)
common_attn_metadata.query_start_loc.zero_()
common_attn_metadata.query_start_loc_cpu.zero_()
return attn_metadata
@@ -211,14 +207,6 @@ class RocmAttentionBackend(AttentionBackend):
return RocmAttentionMetadataBuilder
@cache
def use_aiter_unified_attention() -> bool:
"""Check if aiter unified attention should be used."""
# VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set
# to 1 as default
return envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION
class RocmAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym
@@ -268,23 +256,6 @@ class RocmAttentionImpl(AttentionImpl):
)
self.fp8_dtype = current_platform.fp8_dtype()
self.force_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
if not self.force_prefill_decode_attn:
# If not using prefill decode attention, we use the Triton
# unified attention implementation.
if use_aiter_unified_attention():
logger.info_once("Using aiter unified attention for RocmAttentionImpl")
from aiter.ops.triton.unified_attention import unified_attention
self.unified_attention = unified_attention
else:
logger.info_once("Using vllm unified attention for RocmAttentionImpl")
from vllm.attention.ops.triton_unified_attention import (
unified_attention,
)
self.unified_attention = unified_attention
self.sinks = sinks
if sinks is not None:
@@ -341,58 +312,32 @@ class RocmAttentionImpl(AttentionImpl):
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
use_prefill_decode_attn = self.force_prefill_decode_attn
num_actual_tokens = attn_metadata.num_actual_tokens
if use_prefill_decode_attn:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
else:
key_cache, value_cache = kv_cache.unbind(0)
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size
)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
if use_prefill_decode_attn:
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
else:
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
PagedAttention.write_to_paged_cache(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
num_tokens, num_heads, head_size = query.shape
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
if current_platform.is_cuda():
# Skip Q quantization on ROCm and XPU, enable this on cuda
# only, since dequantizing back to f32 in the attention kernel
# is not supported.
query, _ = ops.scaled_fp8_quant(
query.reshape((num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale,
)
query = query.reshape((num_tokens, num_heads, head_size))
cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
@@ -400,53 +345,27 @@ class RocmAttentionImpl(AttentionImpl):
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
if use_prefill_decode_attn:
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale,
output_scale=output_scale,
sinks=self.sinks,
)
else:
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
self.unified_attention(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
sinks=self.sinks,
output_scale=output_scale,
)
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale,
output_scale=output_scale,
sinks=self.sinks,
)
return output