From f231e5bc21d5c972bffe3aae286c0102c5623e48 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Mon, 6 Oct 2025 18:49:23 -0400 Subject: [PATCH] [ROCm] Split AITER unified attention into its own backend (#25507) Signed-off-by: Gregory Shtrasberg --- tests/compile/test_fusion_attn.py | 218 +++++------------- vllm/attention/backends/registry.py | 1 + vllm/attention/selector.py | 1 + vllm/engine/arg_utils.py | 1 + vllm/envs.py | 13 +- vllm/platforms/rocm.py | 28 ++- .../backends/rocm_aiter_unified_attn.py | 203 ++++++++++++++++ vllm/v1/attention/backends/rocm_attn.py | 161 ++++--------- 8 files changed, 325 insertions(+), 301 deletions(-) create mode 100644 vllm/v1/attention/backends/rocm_aiter_unified_attn.py diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 1fd5c267650b..54d3d4ed0295 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -7,9 +7,7 @@ import pytest import torch._dynamo from tests.compile.backend import LazyInitPass, TestBackend -from tests.models.utils import check_outputs_equal from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata -from vllm import LLM, SamplingParams from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention import Attention, AttentionMetadata from vllm.attention.backends.registry import _Backend @@ -31,7 +29,6 @@ from vllm.config import ( ) from vllm.forward_context import get_forward_context, set_forward_context from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, kFp8StaticTensorSym, kNvfp4Quant, ) @@ -48,132 +45,6 @@ backend: Optional[TestBackend] = None backend_unfused: Optional[TestBackend] = None -@pytest.mark.parametrize( - "model, quant_key", [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)] -) -@pytest.mark.parametrize("use_triton_fa", [True, False]) -@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") -@pytest.mark.skipif( - not current_platform.is_rocm(), reason="V0 attn quant fusion only on ROCm" -) -def test_attention_fusion_v0( - example_prompts, monkeypatch, model: str, quant_key: QuantKey, use_triton_fa: bool -): - # Clean Dynamo cache to avoid reusing other test cases - # (for some reason the reset at the end is not enough) - torch._dynamo.reset() - - # Use global backends - global backend, backend_unfused - - monkeypatch.setenv("VLLM_USE_V1", "1") - monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", str(int(use_triton_fa))) - - # Prompt 4 seems too open-ended, differs between fused and unfused - # (both outputs look reasonable though) - prompts = example_prompts[:4] + example_prompts[5:] - - compile_config = CompilationConfig( - # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation - # DYNAMO_ONCE does not properly propagate shapes. 
- level=CompilationLevel.DYNAMO_AS_IS, - backend="tests.compile.test_fusion_attn.backend_unfused", - custom_ops=["+quant_fp8"], - ) - vllm_config = VllmConfig( - compilation_config=compile_config, - model_config=ModelConfig( - model=model, - dtype=torch.bfloat16, - ), - ) - backend_unfused = TestBackend(NoOpEliminationPass(vllm_config)) - - llm = LLM( - model, - enforce_eager=True, - compilation_config=compile_config, - gpu_memory_utilization=0.5, - max_model_len=2048, - ) - - sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_p=0.95) - - unfused_output = llm.generate(prompts, sampling_params) - backend_unfused = None # Reset backend to make sure llm gets released - del llm - - compile_config = CompilationConfig( - # DYNAMO_AS_IS triggers custom backend & does full Dynamo compilation - # DYNAMO_ONCE does not properly propagate shapes. - level=CompilationLevel.DYNAMO_AS_IS, - backend="tests.compile.test_fusion_attn.backend", - custom_ops=["+quant_fp8"], - ) - vllm_config = VllmConfig( - compilation_config=compile_config, - model_config=ModelConfig( - model=model, - dtype=torch.bfloat16, - ), - ) - - # AttnFusionPass needs attention layers to be registered in config upon init - # so we initialize it during compilation. - attn_pass = LazyInitPass(AttnFusionPass, vllm_config) - backend = TestBackend(NoOpEliminationPass(vllm_config), attn_pass) - llm2 = LLM( - model, - enforce_eager=True, - compilation_config=compile_config, - gpu_memory_utilization=0.5, - max_model_len=2048, - ) - - # check support - attn_fusion_supported = [ - layer.impl.fused_output_quant_supported(quant_key) - for key, layer in compile_config.static_forward_context.items() - ] - - print(f"{attn_fusion_supported=}") - if any(attn_fusion_supported): - # Check quant ops - backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False) - - # attention ops present in both, just output_scale param changes - attn_nodes_pre = list(find_op_nodes(ATTN_OP, backend.graph_pre_pass)) - attn_nodes_post = list(find_op_nodes(ATTN_OP, backend.graph_post_pass)) - assert len(attn_nodes_pre) == len(attn_nodes_post) - - for i in range(len(attn_nodes_pre)): - assert attn_nodes_pre[i].kwargs["output_scale"] is None - fused = attn_nodes_post[i].kwargs["output_scale"] is not None - assert fused == attn_fusion_supported[i], ( - f"Node {i} {'' if fused else 'not '} expected to have fused output quant" - ) - - # check outputs - fused_output = llm2.generate(prompts, sampling_params) - - # transform outputs to format expected by check_outputs_equal - sample_outs = lambda s: (list(s.token_ids), s.text) - outs_lst = lambda ros: [sample_outs(ro.outputs[0]) for ro in ros] - - check_outputs_equal( - outputs_0_lst=outs_lst(unfused_output), - outputs_1_lst=outs_lst(fused_output), - name_0="unfused", - name_1="fused", - ) - - # Clean Dynamo cache to avoid polluting other case(s) - torch._dynamo.reset() - - # Reset backend to make sure llm2 gets released - backend = None - - class AttentionQuantPatternModel(torch.nn.Module): """Base model for AttentionQuantPattern fusion.""" @@ -221,7 +92,7 @@ class AttentionQuantPatternModel(torch.nn.Module): device=self.device, ) - def build_attn_metadata(self, batch_size: int, use_hnd: bool) -> AttentionMetadata: + def build_attn_metadata(self, batch_size: int) -> AttentionMetadata: """Initialize attention metadata.""" # Create common attn metadata @@ -232,30 +103,57 @@ class AttentionQuantPatternModel(torch.nn.Module): max_blocks = (max(batch_spec.seq_lens) + self.block_size - 1) // self.block_size 
num_blocks = batch_size * max_blocks + backend = self.attn.backend - # Create dummy KV cache for FlashInfer TRTLLM - # - NHD: [num_blocks, block_size, num_kv_heads, head_size] - # - HND: [num_blocks, num_kv_heads, block_size, head_size] - kv_cache = torch.zeros( - num_blocks, - 2, - self.num_kv_heads, - self.block_size, - self.head_size, - dtype=self.kv_cache_dtype, - device=self.device, - ) - if current_platform.is_rocm(): + # Create dummy KV cache for the selected backend + if backend == _Backend.ROCM_ATTN: # k/v as 1st dimention - if use_hnd: - kv_cache = kv_cache.permute(1, 0, 2, 3, 4) - else: - kv_cache = kv_cache.permute(1, 0, 3, 2, 4) - else: + # HND: [num_blocks, num_kv_heads, block_size, head_size] + kv_cache = torch.zeros( + 2, + num_blocks, + self.num_kv_heads, + self.block_size, + self.head_size, + dtype=self.kv_cache_dtype, + device=self.device, + ) + elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN: + # k/v as 1st dimention + # NHD: [num_blocks, block_size, num_kv_heads, head_size] + kv_cache = torch.zeros( + 2, + num_blocks, + self.block_size, + self.num_kv_heads, + self.head_size, + dtype=self.kv_cache_dtype, + device=self.device, + ) + elif backend == _Backend.TRITON_ATTN: # k/v as 2nd dimention - # Create kv_cache in HND layout and permute to NHD layout - # (later will be permuted back to HND layout in forward pass) - kv_cache = kv_cache.permute(0, 1, 3, 2, 4) + # NHD: [num_blocks, block_size, num_kv_heads, head_size] + kv_cache = torch.zeros( + num_blocks, + 2, + self.num_kv_heads, + self.block_size, + self.head_size, + dtype=self.kv_cache_dtype, + device=self.device, + ) + elif backend == _Backend.FLASHINFER: + kv_cache = torch.zeros( + num_blocks, + 2, + self.num_kv_heads, + self.block_size, + self.head_size, + dtype=self.kv_cache_dtype, + device=self.device, + ).permute(0, 1, 3, 2, 4) + else: + raise ValueError(f"Unsupported backend: {backend}") self.attn.kv_cache = [kv_cache] # Build attn metadata @@ -375,10 +273,9 @@ else: @pytest.mark.parametrize("model_name, model_class", MODELS) @pytest.mark.parametrize( "backend", - [_Backend.FLASHINFER] if current_platform.is_cuda() else [_Backend.TRITON_ATTN], -) -@pytest.mark.parametrize( - "split_attention", [False, True] if current_platform.is_rocm() else [False] + [_Backend.FLASHINFER] + if current_platform.is_cuda() + else [_Backend.ROCM_AITER_UNIFIED_ATTN, _Backend.ROCM_ATTN, _Backend.TRITON_ATTN], ) # TODO(boyuan): test inductor graph partition on rocm @pytest.mark.parametrize( @@ -405,7 +302,6 @@ def test_attention_quant_pattern( model_name: str, model_class: type[AttentionQuantPatternModel], backend: _Backend, - split_attention: bool, use_inductor_graph_partition: bool, monkeypatch, dist_init, @@ -417,8 +313,6 @@ def test_attention_quant_pattern( pytest.skip("inductor graph partition is only available in PyTorch 2.9+") monkeypatch.setenv("VLLM_USE_V1", "1") - if split_attention: - monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1") device = torch.device("cuda:0") torch.manual_seed(42) @@ -466,9 +360,7 @@ def test_attention_quant_pattern( model_unfused = model_unfused.to(device) forward_ctx = get_forward_context() - forward_ctx.attn_metadata = model_unfused.build_attn_metadata( - batch_size, use_hnd=split_attention - ) + forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size) # Run model directly without compilation and fusion result_unfused = model_unfused(q, k, v) @@ -494,9 +386,7 @@ def test_attention_quant_pattern( model_fused = model_fused.to(device) forward_ctx = 
get_forward_context() - forward_ctx.attn_metadata = model_fused.build_attn_metadata( - batch_size, use_hnd=split_attention - ) + forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size) # Create test backend with fusion passes enabled noop_pass = NoOpEliminationPass(vllm_config) diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 6377e8619b3c..06f13044d572 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -25,3 +25,4 @@ class _Backend(enum.Enum): FLEX_ATTENTION = enum.auto() TREE_ATTN = enum.auto() ROCM_ATTN = enum.auto() + ROCM_AITER_UNIFIED_ATTN = enum.auto() diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index effd35444d54..3a5bbb997286 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -254,3 +254,4 @@ def global_force_attn_backend_context_manager( finally: # Revert the original global backend override, if any global_force_attn_backend(original_value) + _cached_get_attn_backend.cache_clear() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a94ef598f2de..942384688184 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1623,6 +1623,7 @@ class EngineArgs: "TREE_ATTN", "XFORMERS", "ROCM_ATTN", + "ROCM_AITER_UNIFIED_ATTN", ] if ( envs.is_set("VLLM_ATTENTION_BACKEND") diff --git a/vllm/envs.py b/vllm/envs.py index a4f53925626b..2b915b02e48f 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -18,7 +18,6 @@ if TYPE_CHECKING: LD_LIBRARY_PATH: Optional[str] = None VLLM_USE_TRITON_FLASH_ATTN: bool = True VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False - VLLM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_FLASH_ATTN_VERSION: Optional[int] = None LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None @@ -109,6 +108,7 @@ if TYPE_CHECKING: VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False VLLM_ROCM_USE_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True + VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False VLLM_ROCM_USE_SKINNY_GEMM: bool = True VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True @@ -475,10 +475,6 @@ environment_variables: dict[str, Callable[[], Any]] = { os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in ("true", "1") ), - # Use AITER triton unified attention for V1 attention - "VLLM_USE_AITER_UNIFIED_ATTENTION": lambda: ( - os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in ("true", "1") - ), # Force vllm to use a specific flash-attention version (2 or 3), only valid # when using the flash-attention backend. 
"VLLM_FLASH_ATTN_VERSION": lambda: maybe_convert_int( @@ -896,6 +892,11 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_ROCM_USE_AITER_FP8BMM": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_FP8BMM", "True").lower() in ("true", "1") ), + # Use AITER triton unified attention for V1 attention + "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "False").lower() + in ("true", "1") + ), # use rocm skinny gemms "VLLM_ROCM_USE_SKINNY_GEMM": lambda: ( os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in ("true", "1") @@ -1434,7 +1435,6 @@ def compute_hash() -> str: "VLLM_FUSED_MOE_CHUNK_SIZE", "VLLM_FLASHINFER_MOE_BACKEND", "VLLM_V1_USE_PREFILL_DECODE_ATTENTION", - "VLLM_USE_AITER_UNIFIED_ATTENTION", "VLLM_ATTENTION_BACKEND", "VLLM_USE_FLASHINFER_SAMPLER", "VLLM_DISABLED_KERNELS", @@ -1462,6 +1462,7 @@ def compute_hash() -> str: "VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "VLLM_ROCM_USE_TRITON_ROPE", "VLLM_ROCM_USE_AITER_FP8BMM", + "VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION", "VLLM_ROCM_USE_SKINNY_GEMM", "VLLM_ROCM_FP8_PADDING", "VLLM_ROCM_MOE_PADDING", diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 80e7b849c0ed..25601011491f 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -276,25 +276,33 @@ class RocmPlatform(Platform): ) if envs.VLLM_USE_V1: - if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9(): - logger.info("Using Flash Attention backend on V1 engine.") + if ( + envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9() + ) or selected_backend == _Backend.ROCM_AITER_FA: + logger.info("Using Aiter Flash Attention backend on V1 engine.") return ( "vllm.v1.attention.backends." "rocm_aiter_fa.AiterFlashAttentionBackend" ) - elif ( - (envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION) - or envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION + if ( + envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION + ) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN: + logger.info("Using Aiter Unified Attention backend on V1 engine.") + return ( + "vllm.v1.attention.backends." + "rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend" + ) + if ( + envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or selected_backend == _Backend.ROCM_ATTN ): # rocm specific backend, with aiter and/or # triton prefix-prefill - logger.info("Using Rocm/Aiter Attention backend on V1 engine.") + logger.info("Using Rocm Attention backend on V1 engine.") return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend" - else: - # default case, using triton unified attention - logger.info("Using Triton Attention backend on V1 engine.") - return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" + # default case, using triton unified attention + logger.info("Using Triton Attention backend on V1 engine.") + return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" raise RuntimeError( "V0 attention backends have been removed. Set VLLM_USE_V1=1 " "to select a supported backend." 
diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py new file mode 100644 index 000000000000..235ea1c376ef --- /dev/null +++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Attention layer with PagedAttention and Triton prefix prefill.""" + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import AttentionMetadata, AttentionType +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8StaticTensorSym, +) +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.rocm_attn import ( + RocmAttentionBackend, + RocmAttentionImpl, + RocmAttentionMetadata, + RocmAttentionMetadataBuilder, +) + +logger = init_logger(__name__) + + +class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ROCM_AITER_UNIFIED_ATTN" + + @staticmethod + def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]: + return RocmAiterUnifiedAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return RocmAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + + @staticmethod + def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]: + return RocmAttentionMetadataBuilder + + +class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): + def fused_output_quant_supported(self, quant_key: QuantKey): + return quant_key == kFp8StaticTensorSym + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[int] = None, + sinks: Optional[torch.Tensor] = None, + ) -> None: + super().__init__( + num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window, + kv_cache_dtype, + logits_soft_cap, + attn_type, + kv_sharing_target_layer_name, + sinks, + ) + logger.info_once( + "Using aiter unified attention for RocmAiterUnifiedAttentionImpl" + ) + from aiter.ops.triton.unified_attention import unified_attention + + self.unified_attention = unified_attention + + def forward( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. 
+ + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + kv_cache: shape = + [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + + if output_block_scale is not None: + raise NotImplementedError( + "fused block_scale output quantization is not yet supported" + " for RocmAttentionImpl" + ) + + if attn_metadata is None: + # Profiling run. + return output + + assert attn_metadata.use_cascade is False + + # IMPORTANT! + # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in + # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead + # in this method. For example, `view` and `slice` (or `[:n]`) operations + # are surprisingly slow even in the case they do not invoke any GPU ops. + # Minimize the PyTorch ops in this method as much as possible. + # Whenever making a change in this method, please benchmark the + # performance to make sure it does not introduce any overhead. + + num_actual_tokens = attn_metadata.num_actual_tokens + + key_cache, value_cache = kv_cache.unbind(0) + + if self.kv_sharing_target_layer_name is None: + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.kv_cache_dtype.startswith("fp8"): + key_cache = key_cache.view(self.fp8_dtype) + value_cache = value_cache.view(self.fp8_dtype) + assert layer._q_scale_float == 1.0, ( + "A non 1.0 q_scale is not currently supported." + ) + + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + + descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + + self.unified_attention( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=cu_seqlens_q, + max_seqlen_q=max_seqlen_q, + seqused_k=seqused_k, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=block_table, + softcap=self.logits_soft_cap, + q_descale=None, # Not supported + k_descale=layer._k_scale.expand(descale_shape), + v_descale=layer._v_scale.expand(descale_shape), + sinks=self.sinks, + output_scale=output_scale, + ) + + return output diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 4c24770aa22c..10dd01f0a5aa 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -3,13 +3,10 @@ """Attention layer with PagedAttention and Triton prefix prefill.""" from dataclasses import dataclass -from functools import cache from typing import ClassVar, Optional import torch -from vllm import _custom_ops as ops -from vllm import envs from vllm.attention.backends.abstract import ( AttentionBackend, AttentionImpl, @@ -96,12 +93,11 @@ class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadat # slow, so here we set it to 1. 
attn_metadata.seq_lens.fill_(1) - if envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION: - # Here we set the query start locs to 0. This is to - # cover up an invalid memory access in the prefix_prefil kernel - # that we run into during graph capture (#25985) - common_attn_metadata.query_start_loc.zero_() - common_attn_metadata.query_start_loc_cpu.zero_() + # Here we set the query start locs to 0. This is to + # cover up an invalid memory access in the prefix_prefil kernel + # that we run into during graph capture (#25985) + common_attn_metadata.query_start_loc.zero_() + common_attn_metadata.query_start_loc_cpu.zero_() return attn_metadata @@ -211,14 +207,6 @@ class RocmAttentionBackend(AttentionBackend): return RocmAttentionMetadataBuilder -@cache -def use_aiter_unified_attention() -> bool: - """Check if aiter unified attention should be used.""" - # VLLM_ROCM_USE_AITER_MHA needs to set to 0 as well as it is set - # to 1 as default - return envs.VLLM_ROCM_USE_AITER and envs.VLLM_USE_AITER_UNIFIED_ATTENTION - - class RocmAttentionImpl(AttentionImpl): def fused_output_quant_supported(self, quant_key: QuantKey): return quant_key == kFp8StaticTensorSym @@ -268,23 +256,6 @@ class RocmAttentionImpl(AttentionImpl): ) self.fp8_dtype = current_platform.fp8_dtype() - self.force_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION - - if not self.force_prefill_decode_attn: - # If not using prefill decode attention, we use the Triton - # unified attention implementation. - if use_aiter_unified_attention(): - logger.info_once("Using aiter unified attention for RocmAttentionImpl") - from aiter.ops.triton.unified_attention import unified_attention - - self.unified_attention = unified_attention - else: - logger.info_once("Using vllm unified attention for RocmAttentionImpl") - from vllm.attention.ops.triton_unified_attention import ( - unified_attention, - ) - - self.unified_attention = unified_attention self.sinks = sinks if sinks is not None: @@ -341,58 +312,32 @@ class RocmAttentionImpl(AttentionImpl): # Whenever making a change in this method, please benchmark the # performance to make sure it does not introduce any overhead. - use_prefill_decode_attn = self.force_prefill_decode_attn num_actual_tokens = attn_metadata.num_actual_tokens - if use_prefill_decode_attn: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size - ) - else: - key_cache, value_cache = kv_cache.unbind(0) + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size + ) if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # Skip this if sharing KV cache with an earlier attention layer. 
- if use_prefill_decode_attn: - PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - else: - ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) if self.kv_cache_dtype.startswith("fp8"): key_cache = key_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype) - num_tokens, num_heads, head_size = query.shape assert layer._q_scale_float == 1.0, ( "A non 1.0 q_scale is not currently supported." ) - if current_platform.is_cuda(): - # Skip Q quantization on ROCm and XPU, enable this on cuda - # only, since dequantizing back to f32 in the attention kernel - # is not supported. - query, _ = ops.scaled_fp8_quant( - query.reshape((num_tokens, num_heads * head_size)).contiguous(), - layer._q_scale, - ) - query = query.reshape((num_tokens, num_heads, head_size)) cu_seqlens_q = attn_metadata.query_start_loc seqused_k = attn_metadata.seq_lens @@ -400,53 +345,27 @@ class RocmAttentionImpl(AttentionImpl): max_seqlen_k = attn_metadata.max_seq_len block_table = attn_metadata.block_table - if use_prefill_decode_attn: - # Compute attention and update output up to `num_actual_tokens`. - chunked_prefill_paged_decode( - query=query[:num_actual_tokens], - key=key[:num_actual_tokens], - value=value[:num_actual_tokens], - output=output[:num_actual_tokens], - kv_cache_dtype=self.kv_cache_dtype, - key_cache=key_cache, - value_cache=value_cache, - block_table=block_table, - query_start_loc=cu_seqlens_q, - seq_lens=seqused_k, - max_seq_len=max_seqlen_k, - max_query_len=max_seqlen_q, - k_scale=layer._k_scale, - v_scale=layer._v_scale, - alibi_slopes=self.alibi_slopes, - sliding_window=self.sliding_window[0], - sm_scale=self.scale, - output_scale=output_scale, - sinks=self.sinks, - ) - - else: - descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) - - self.unified_attention( - q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - out=output[:num_actual_tokens], - cu_seqlens_q=cu_seqlens_q, - max_seqlen_q=max_seqlen_q, - seqused_k=seqused_k, - max_seqlen_k=max_seqlen_k, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - window_size=self.sliding_window, - block_table=block_table, - softcap=self.logits_soft_cap, - q_descale=None, # Not supported - k_descale=layer._k_scale.expand(descale_shape), - v_descale=layer._v_scale.expand(descale_shape), - sinks=self.sinks, - output_scale=output_scale, - ) + # Compute attention and update output up to `num_actual_tokens`. + chunked_prefill_paged_decode( + query=query[:num_actual_tokens], + key=key[:num_actual_tokens], + value=value[:num_actual_tokens], + output=output[:num_actual_tokens], + kv_cache_dtype=self.kv_cache_dtype, + key_cache=key_cache, + value_cache=value_cache, + block_table=block_table, + query_start_loc=cu_seqlens_q, + seq_lens=seqused_k, + max_seq_len=max_seqlen_k, + max_query_len=max_seqlen_q, + k_scale=layer._k_scale, + v_scale=layer._v_scale, + alibi_slopes=self.alibi_slopes, + sliding_window=self.sliding_window[0], + sm_scale=self.scale, + output_scale=output_scale, + sinks=self.sinks, + ) return output
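After the split, the practical difference between the two ROCm backends is the KV-cache layout each consumes: RocmAttentionBackend keeps the HND cache written by PagedAttention.write_to_paged_cache, while RocmAiterUnifiedAttentionBackend uses the NHD cache written by ops.reshape_and_cache_flash, matching its get_kv_cache_shape of (2, num_blocks, block_size, num_kv_heads, head_size). A small sketch of the two shapes as built in the updated fusion test; the sizes below are placeholders:

    import torch

    num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128  # placeholder sizes
    dtype, device = torch.bfloat16, "cpu"  # the test uses the model dtype on the GPU device

    # ROCM_ATTN: k/v stacked first, then HND [num_blocks, num_kv_heads, block_size, head_size]
    rocm_attn_cache = torch.zeros(
        2, num_blocks, num_kv_heads, block_size, head_size, dtype=dtype, device=device
    )

    # ROCM_AITER_UNIFIED_ATTN: k/v stacked first, then NHD
    # [num_blocks, block_size, num_kv_heads, head_size]
    aiter_unified_cache = torch.zeros(
        2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device
    )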