diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index a6c953ee0eac9..32960cc8073bb 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -552,7 +552,11 @@ __global__ void indexer_k_quant_and_cache_kernel(
 #ifndef USE_ROCM
   __syncwarp();
 #endif
+#if defined(__gfx942__)
+  float scale = fmaxf(amax, 1e-4) / 224.0f;
+#else
   float scale = fmaxf(amax, 1e-4) / 448.0f;
+#endif
   if (use_ue8m0) {
     scale = exp2f(ceilf(log2f(scale)));
   }
diff --git a/vllm/attention/ops/rocm_aiter_mla_sparse.py b/vllm/attention/ops/rocm_aiter_mla_sparse.py
new file mode 100644
index 0000000000000..080e92ecc9408
--- /dev/null
+++ b/vllm/attention/ops/rocm_aiter_mla_sparse.py
@@ -0,0 +1,210 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import importlib.util
+from functools import lru_cache
+
+import torch
+
+from vllm._aiter_ops import rocm_aiter_ops
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+
+logger = init_logger(__name__)
+
+
+# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
+def fp8_mqa_logits_torch(
+    q: torch.Tensor,
+    kv: tuple[torch.Tensor, torch.Tensor],
+    weights: torch.Tensor,
+    cu_seqlen_ks: torch.Tensor,
+    cu_seqlen_ke: torch.Tensor,
+) -> torch.Tensor:
+    """Compute FP8 MQA logits for a single sequence without KV paging.
+
+    Args:
+        q: Query tensor of shape [M, H, D]. Cast to
+            `torch.float8_e4m3fn` by the caller.
+        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
+            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
+            [N, 1]) with dtype `torch.float32`.
+        weights: Weights of shape [M, H], dtype `torch.float32`.
+        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
+            shape [M], dtype int32.
+        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
+            shape [M], dtype int32.
+
+    Returns:
+        Logits tensor of shape [M, N], dtype `torch.float32`.
+    """
+    kv, scale = kv
+    seq_len_kv = kv.shape[0]
+    k = kv.to(torch.bfloat16)
+    q = q.to(torch.bfloat16)
+
+    mask_lo = (
+        torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
+    )
+    mask_hi = (
+        torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
+    )
+    mask = mask_lo & mask_hi
+
+    score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
+    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
+    logits = logits.masked_fill(~mask, float("-inf"))
+
+    return logits
+
+
+def rocm_fp8_mqa_logits(
+    q: torch.Tensor,
+    kv: tuple[torch.Tensor, torch.Tensor],
+    weights: torch.Tensor,
+    cu_seqlen_ks: torch.Tensor,
+    cu_seqlen_ke: torch.Tensor,
+) -> torch.Tensor:
+    """Compute FP8 MQA logits for a single sequence without KV paging.
+
+    Args:
+        q: Query tensor of shape [M, H, D]. Cast to
+            `torch.float8_e4m3fn` by the caller.
+        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
+            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
+            [N, 1]) with dtype `torch.float32`.
+        weights: Weights of shape [M, H], dtype `torch.float32`.
+        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
+            shape [M], dtype int32.
+        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
+            shape [M], dtype int32.
+
+    Returns:
+        Logits tensor of shape [M, N], dtype `torch.float32`.
+    """
+
+    # TODO(ganyi): Temporary workaround; will remove the module check and reference
+    # path once aiter merges this kernel into main.
+    @lru_cache
+    def has_mqa_logits_module():
+        return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None
+
+    if rocm_aiter_ops.is_enabled() and has_mqa_logits_module():
+        from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits
+
+        kv, scale = kv
+        return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
+    else:
+        return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
+
+
+# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
+def fp8_paged_mqa_logits_torch(
+    q: torch.Tensor,
+    kv_cache: torch.Tensor,
+    weights: torch.Tensor,
+    context_lens: torch.Tensor,
+    block_tables: torch.Tensor,
+    max_model_len: int,
+):
+    from vllm.utils.math_utils import cdiv
+
+    fp8_dtype = current_platform.fp8_dtype()
+    batch_size, next_n, _, dim = q.size()
+    kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
+    scale = scale.contiguous().view(torch.float)
+    q = q.float()
+    kv_cache = kv_cache.view(fp8_dtype).float() * scale
+    num_block, block_size, _, dim = kv_cache.size()
+    logits = torch.full(
+        [batch_size * next_n, max_model_len],
+        float("-inf"),
+        device=q.device,
+        dtype=torch.float32,
+    )
+    context_lens = context_lens.tolist()
+    for i in range(batch_size):
+        context_len = context_lens[i]
+        q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
+        weight_slice = (
+            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
+        )
+        for block_rk in range(cdiv(context_len, block_size)):
+            block_idx = block_tables[i][block_rk]
+            qx, kx = q[i], kv_cache[block_idx]
+            k_offsets = torch.arange(
+                block_rk * block_size, (block_rk + 1) * block_size, device="cuda"
+            )
+            mask = (k_offsets[None, :] < context_len) & (
+                k_offsets[None, :] <= q_offsets[:, None]
+            )
+            s = torch.where(
+                mask[None, :, :],
+                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
+                    logits.dtype
+                ),
+                float("-inf"),
+            )
+            s = torch.relu(s) * weight_slice[..., None]
+            s = s.sum(dim=0)
+            logits[
+                i * next_n : (i + 1) * next_n,
+                block_rk * block_size : (block_rk + 1) * block_size,
+            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
+    return logits
+
+
+def rocm_fp8_paged_mqa_logits(
+    q_fp8: torch.Tensor,
+    kv_cache_fp8: torch.Tensor,
+    weights: torch.Tensor,
+    context_lens: torch.Tensor,
+    block_tables: torch.Tensor,
+    schedule_metadata: torch.Tensor,
+    max_model_len: int,
+) -> torch.Tensor:
+    """Compute FP8 MQA logits using a paged KV cache.
+
+    Args:
+        q_fp8: Query tensor of shape [B, next_n, H, D]. Cast to
+            `torch.float8_e4m3fn` by the caller.
+        kv_cache_fp8: Paged KV cache in packed FP8+scale layout with shape
+            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
+            4 bytes per (block, pos) store the `float` dequant scale.
+        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
+        context_lens: Tensor of shape [B], dtype int32; effective context length
+            for each batch element.
+        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
+            block indices to physical blocks in the paged cache.
+        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
+            used to distribute work across SMs.
+        max_model_len: Maximum sequence length used to size the logits output.
+
+    Returns:
+        Logits tensor of shape [B * next_n, max_model_len], dtype
+        `torch.float32`.
+ """ + + if rocm_aiter_ops.is_enabled(): + from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1 + + batch_size, next_n, heads, _ = q_fp8.shape + out_qk = torch.full( + (heads, batch_size * next_n, max_model_len), + float("-inf"), + device="cuda", + dtype=torch.float32, + ) + deepgemm_fp8_paged_mqa_logits_stage1( + q_fp8, + kv_cache_fp8, + weights, + out_qk, + context_lens, + block_tables, + max_model_len, + ) + return out_qk.sum(dim=0) + else: + return fp8_paged_mqa_logits_torch( + q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len + ) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index d0a116b97997a..7cfd381592b49 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -594,6 +594,7 @@ def sparse_attn_indexer( ) -> torch.Tensor: # careful! this will be None in dummy run attn_metadata = get_forward_context().attn_metadata + fp8_dtype = current_platform.fp8_dtype() # assert isinstance(attn_metadata, dict) if not isinstance(attn_metadata, dict): return sparse_attn_indexer_fake( @@ -633,7 +634,7 @@ def sparse_attn_indexer( k_fp8 = torch.empty( [chunk.total_seq_lens, head_dim], device=k.device, - dtype=torch.float8_e4m3fn, + dtype=fp8_dtype, ) k_scale = torch.empty( [chunk.total_seq_lens, 4], @@ -647,7 +648,12 @@ def sparse_attn_indexer( chunk.block_table, chunk.cu_seq_lens, ) - logits = fp8_mqa_logits( + fp8_mqa_logits_func = fp8_mqa_logits + if current_platform.is_rocm(): + from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits + + fp8_mqa_logits_func = rocm_fp8_mqa_logits + logits = fp8_mqa_logits_func( q_fp8[chunk.token_start : chunk.token_end], (k_fp8, k_scale.view(torch.float32)), weights[chunk.token_start : chunk.token_end], @@ -692,7 +698,14 @@ def sparse_attn_indexer( next_n = padded_q_fp8_decode_tokens.shape[1] assert batch_size == decode_metadata.seq_lens.shape[0] num_padded_tokens = batch_size * next_n - logits = fp8_paged_mqa_logits( + fp8_paged_mqa_logits_func = fp8_paged_mqa_logits + if current_platform.is_rocm(): + from vllm.attention.ops.rocm_aiter_mla_sparse import ( + rocm_fp8_paged_mqa_logits, + ) + + fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits + logits = fp8_paged_mqa_logits_func( padded_q_fp8_decode_tokens, kv_cache, weights[:num_padded_tokens], @@ -749,7 +762,8 @@ def sparse_attn_indexer_fake( _flattened_kv = torch.empty( [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8 ) - _k_fp8 = _flattened_kv[..., :head_dim].view(torch.float8_e4m3fn).contiguous() + fp8_dtype = current_platform.fp8_dtype() + _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous() _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() return topk_indices_buffer diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index f07f068a9249b..1a2f9226ddce8 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -225,7 +225,18 @@ class RocmPlatform(Platform): from vllm.attention.backends.registry import AttentionBackendEnum if use_sparse: - raise NotImplementedError("Sparse Attention is not supported on ROCm.") + if kv_cache_dtype.startswith("fp8"): + raise ValueError( + "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype." + ) + assert block_size == 1, ( + "Sparse MLA backend on ROCm only supports block size 1 for now." + ) + logger.info_once("Using Sparse MLA backend on V1 engine.") + return ( + "vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse." 
+ "ROCMAiterMLASparseBackend" + ) if use_mla: if selected_backend is None: diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 6b0a383a0e28c..b25c1e3e1ece3 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -325,6 +325,7 @@ DEFAULT_BLOCK_SIZE = [128, 128] def per_block_cast_to_fp8( x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False ) -> tuple[torch.Tensor, torch.Tensor]: + fp8_dtype = current_platform.fp8_dtype() assert x.dim() == 2 m, n = x.shape block_m, block_n = block_size @@ -334,9 +335,9 @@ def per_block_cast_to_fp8( x_padded[:m, :n] = x x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - sf = x_amax / 448.0 + sf = x_amax / 224.0 if current_platform.is_fp8_fnuz() else x_amax / 448.0 sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf - x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + x_scaled = (x_view * (1.0 / sf)).to(fp8_dtype) return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view( x_view.size(0), x_view.size(2) ) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index bb8d914d15719..3f2cc8c38327e 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -168,7 +168,7 @@ def _convert_req_index_to_global_index_kernel( inblock_off = tok % BLOCK_SIZE # Guard block_table access - valid_block = block_id < max_num_blocks_per_req + valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0) bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 base = tl.load(bt_ptr, mask=valid_block, other=0) diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 37aa5dad89a0e..cc0988435768c 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -11,7 +11,8 @@ from vllm.attention.backends.abstract import ( ) from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata +from vllm.platforms import current_platform +from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -23,7 +24,9 @@ logger = init_logger(__name__) class DeepseekV32IndexerBackend(AttentionBackend): - supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64] + supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [ + 1 if current_platform.is_rocm() else 64 + ] @classmethod def get_supported_head_sizes(cls) -> list[int]: @@ -328,10 +331,10 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item() seq_lens = common_attn_metadata.seq_lens[:num_decodes] - - self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( - seq_lens, self.kv_cache_spec.block_size, self.num_sms - ) + if is_deep_gemm_supported(): + self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( + seq_lens, self.kv_cache_spec.block_size, self.num_sms + ) decode_metadata = DeepSeekV32IndexerDecodeMetadata( block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...], seq_lens=common_attn_metadata.seq_lens[:num_decodes], diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py 
b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py new file mode 100644 index 0000000000000..c0e7f0e380b98 --- /dev/null +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar, Optional + +import numpy as np +import torch + +from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionLayer, + AttentionMetadata, +) +from vllm.attention.backends.utils import get_mla_dims +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import ( + MLACommonBaseImpl, +) +from vllm.v1.attention.backends.mla.flashmla_sparse import ( + triton_convert_req_index_to_global_index, +) +from vllm.v1.attention.backends.utils import ( + AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, +) +from vllm.v1.kv_cache_interface import AttentionSpec + +if TYPE_CHECKING: + from vllm.model_executor.models.deepseek_v2 import Indexer +logger = init_logger(__name__) + + +class ROCMAiterMLASparseBackend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "ROCM_AITER_MLA_SPARSE" + + @staticmethod + def get_metadata_cls() -> type[AttentionMetadata]: + return ROCMAiterMLASparseMetadata + + @staticmethod + def get_builder_cls() -> type["ROCMAiterMLASparseMetadataBuilder"]: + return ROCMAiterMLASparseMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["ROCMAiterMLASparseImpl"]: + return ROCMAiterMLASparseImpl + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + return (num_blocks, block_size, head_size) + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + +@dataclass +class ROCMAiterMLASparseMetadata: + num_reqs: int + max_query_len: int + max_seq_len: int + + num_actual_tokens: int # Number of tokens excluding padding. 
+    query_start_loc: torch.Tensor
+    slot_mapping: torch.Tensor
+
+    block_table: torch.Tensor
+    req_id_per_token: torch.Tensor
+    block_size: int = 1
+    topk_tokens: int = 2048
+
+
+@dataclass
+class ROCMAiterMLASparseMetadataBuilder(
+    AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]
+):
+    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER
+
+    def __init__(
+        self,
+        kv_cache_spec: AttentionSpec,
+        layer_names: list[str],
+        vllm_config: VllmConfig,
+        device: torch.device,
+    ):
+        self.kv_cache_spec = kv_cache_spec
+        self.model_config = vllm_config.model_config
+        parallel_config = vllm_config.parallel_config
+        self.device = device
+
+        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
+        self.mla_dims = get_mla_dims(self.model_config)
+        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
+        self.topk_tokens_tensor = torch.tensor(
+            [self.topk_tokens], device=device, dtype=torch.int32
+        )
+        self.max_model_len_tensor = torch.tensor(
+            [self.model_config.max_model_len], device=device, dtype=torch.int32
+        )
+        # this is ignored by `flash_mla_with_kvcache` when indices is not None
+        self.dummy_block_table = torch.empty(
+            (1, 1), dtype=torch.int32, device=self.device
+        )
+
+        self.req_id_per_token_buffer = torch.empty(
+            (vllm_config.scheduler_config.max_num_batched_tokens,),
+            dtype=torch.int32,
+            device=device,
+        )
+
+    def build(
+        self,
+        common_prefix_len: int,
+        common_attn_metadata: CommonAttentionMetadata,
+        fast_build: bool = False,
+    ) -> ROCMAiterMLASparseMetadata:
+        num_tokens = common_attn_metadata.num_actual_tokens
+        starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
+        seg_lengths = np.diff(starts)
+        req_id_per_token = np.repeat(
+            np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
+        )
+        # Zero-fill for cudagraphs
+        self.req_id_per_token_buffer.fill_(0)
+        self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
+            torch.from_numpy(req_id_per_token), non_blocking=True
+        )
+        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
+
+        metadata = ROCMAiterMLASparseMetadata(
+            num_reqs=common_attn_metadata.num_reqs,
+            max_query_len=common_attn_metadata.max_query_len,
+            max_seq_len=common_attn_metadata.max_seq_len,
+            num_actual_tokens=common_attn_metadata.num_actual_tokens,
+            query_start_loc=common_attn_metadata.query_start_loc,
+            slot_mapping=common_attn_metadata.slot_mapping,
+            block_table=common_attn_metadata.block_table_tensor,
+            req_id_per_token=req_id_per_token,
+            block_size=self.kv_cache_spec.block_size,
+            topk_tokens=self.topk_tokens,
+        )
+        return metadata
+
+
+# Taken from
+# https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla_prefill.py#L72
+def reference_mla_sparse_prefill(
+    q: torch.Tensor, kv: torch.Tensor, indices: torch.Tensor, sm_scale: float, d_v: int
+) -> tuple[torch.Tensor, torch.Tensor]:
+    import math
+
+    def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
+        return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)
+
+    skv = kv.shape[0]
+    sq = q.shape[0]
+    topk = indices.shape[-1]
+    dqk = q.shape[-1]
+    indices = indices[:, 0, :]  # [s_q, topk]
+    invalid_indices_mask = (indices < 0) | (indices >= skv)
+    indices[invalid_indices_mask] = 0
+    qs = q  # [s_q, h_q, d_qk]
+    kvs = kv[:, 0, :][indices].view(sq, topk, dqk)  # [s_q, topk, d_qk]
+
+    attn_score = (qs @ kvs.transpose(1, 2)).float()  # [s_q, h_q, topk]
+    attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
+    attn_score *= sm_scale * math.log2(math.e)
+    lse = log2sumexp2(attn_score, dim=-1)  # [s_q, h_q]
+    attn_score = torch.exp2(attn_score - lse.unsqueeze(-1))  # [s_q, h_q, topk]
+    result = attn_score.to(q.dtype) @ kvs[:, :, :d_v]
+    return (result, lse)
+
+
+class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: list[float] | None,
+        sliding_window: int | None,
+        kv_cache_dtype: str,
+        logits_soft_cap: float | None,
+        attn_type: str,
+        kv_sharing_target_layer_name: str | None,
+        # MLA Specific Arguments
+        topk_indice_buffer: torch.Tensor | None = None,
+        indexer: Optional["Indexer"] = None,
+        **mla_args,
+    ) -> None:
+        super().__init__(
+            num_heads,
+            head_size,
+            scale,
+            num_kv_heads,
+            alibi_slopes,
+            sliding_window,
+            kv_cache_dtype,
+            logits_soft_cap,
+            attn_type,
+            kv_sharing_target_layer_name,
+            **mla_args,
+        )
+        self.softmax_scale = scale
+        assert indexer is not None
+        self.topk_indices_buffer = indexer.topk_indices_buffer
+        self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
+
+    def _forward_bf16_kv(
+        self,
+        q: torch.Tensor,
+        kv_c_and_k_pe_cache: torch.Tensor,
+        topk_indices: torch.Tensor,
+        attn_metadata: ROCMAiterMLASparseMetadata,
+    ) -> torch.Tensor:
+        num_tokens = q.shape[0]
+        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
+            -1, 1, kv_c_and_k_pe_cache.shape[-1]
+        )
+
+        topk_indices = topk_indices.view(num_tokens, 1, -1)
+        output = reference_mla_sparse_prefill(
+            q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale, 512
+        )[0]
+        return output[:, : self.num_heads, :]
+
+    def forward(
+        self,
+        layer: AttentionLayer,
+        q: torch.Tensor,
+        k_c_normed: torch.Tensor,  # key in unified attn
+        k_pe: torch.Tensor,  # value in unified attn
+        kv_cache: torch.Tensor,
+        attn_metadata: ROCMAiterMLASparseMetadata,
+        output: torch.Tensor | None = None,
+        output_scale: torch.Tensor | None = None,
+        output_block_scale: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        # NOTE(lucas): the sparse FlashMLA kernels use the MQA 576/512
+        # approach for both prefill and decode
+
+        assert output is not None, "Output tensor must be provided."
+
+        if output_scale is not None or output_block_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported for ROCMAiterMLASparse"
+            )
+
+        if attn_metadata is None:
+            # The zero fill is required when used with DP + EP
+            # to ensure all ranks within a DP group compute the
+            # same expert outputs.
+            return output.fill_(0)
+
+        num_actual_toks = attn_metadata.num_actual_tokens
+
+        # Inputs and outputs may be padded for CUDA graphs
+
+        q = q[:num_actual_toks, ...]
+        k_c_normed = k_c_normed[:num_actual_toks, ...]
+        k_pe = k_pe[:num_actual_toks, ...]
+ + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + if self.is_fp8bmm_enabled: + # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) + ql_nope = rocm_aiter_ops.triton_fp8_bmm( + q_nope, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True + ) + else: + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + ql_nope = ql_nope.transpose(0, 1) + + topk_indices = self.topk_indices_buffer[:num_actual_toks] + + topk_indices_global = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=attn_metadata.topk_tokens, + ) + + q = torch.cat([ql_nope, q_pe], dim=-1) + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + attn_out = self._forward_bf16_kv( + q, kv_cache, topk_indices_global, attn_metadata + ) + + self._v_up_proj(attn_out, out=output[:num_actual_toks]) + return output diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 095407a8b9596..9e99ea964ee08 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -316,7 +316,7 @@ def bind_kv_cache( # TODO - analyze where runner_kv_caches is used and the right # way to ensure it properly reflects multiple attention layers # in the same decoder block. - if current_platform.is_cuda() or current_platform.is_xpu(): + if current_platform.is_cuda_alike() or current_platform.is_xpu(): # We know that the GPU runner is not impacted by this # case. Some test code depends on runner_kv_caches, but # not in a way that's impacted by ignoring this.
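
Note on the 448.0 vs 224.0 split in the cache_kernels.cu and per_block_cast_to_fp8 hunks: 448.0 is the maximum of OCP `torch.float8_e4m3fn`, while the FNUZ FP8 variant used on gfx942 has a smaller representable range, so the PR divides the clamped amax by 224.0 on that platform instead. A minimal Python sketch of the scale computation both hunks implement; `fp8_scale` and `is_fnuz` are illustrative names, not vLLM APIs:

import torch

def fp8_scale(
    amax: torch.Tensor, is_fnuz: bool, use_ue8m0: bool = False
) -> torch.Tensor:
    # Keep amax away from zero, then divide by the per-platform FP8 max
    # (448.0 for OCP e4m3fn, 224.0 for the FNUZ/gfx942 path in this PR).
    fp8_max = 224.0 if is_fnuz else 448.0
    scale = amax.clamp(min=1e-4) / fp8_max
    if use_ue8m0:
        # UE8M0 scales: round the scale up to the next power of two.
        scale = torch.exp2(torch.ceil(torch.log2(scale)))
    return scale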
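Usage sketch for the torch fallback added in vllm/attention/ops/rocm_aiter_mla_sparse.py. Shapes follow the `fp8_mqa_logits_torch` docstring ([M, H, D] queries, [N, D] FP8 keys with per-key scales); the sizes below are made up for illustration, and a GPU is assumed because the helper hard-codes `device="cuda"` for its masks:

import torch

from vllm.attention.ops.rocm_aiter_mla_sparse import fp8_mqa_logits_torch

M, H, D, N = 4, 16, 128, 32  # illustrative sizes only
q = torch.randn(M, H, D, device="cuda").to(torch.float8_e4m3fn)
k_fp8 = torch.randn(N, D, device="cuda").to(torch.float8_e4m3fn)
k_scales = torch.ones(N, device="cuda", dtype=torch.float32)
weights = torch.rand(M, H, device="cuda", dtype=torch.float32)

# Query position m may attend to keys in [cu_seqlen_ks[m], cu_seqlen_ke[m]).
cu_seqlen_ks = torch.zeros(M, dtype=torch.int32, device="cuda")
cu_seqlen_ke = torch.full((M,), N, dtype=torch.int32, device="cuda")

logits = fp8_mqa_logits_torch(
    q, (k_fp8, k_scales), weights, cu_seqlen_ks, cu_seqlen_ke
)
assert logits.shape == (M, N)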
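The paged reference path expects the packed layout described in the `rocm_fp8_paged_mqa_logits` docstring: per cache position, D FP8 bytes followed by a 4-byte float32 dequant scale, all stored as `torch.uint8`. A sketch of building such a cache and calling the torch reference; sizes are illustrative, block size 1 matches the ROCm sparse path, and the FP8 dtype is assumed to match `current_platform.fp8_dtype()` on the machine running it:

import torch

from vllm.attention.ops.rocm_aiter_mla_sparse import fp8_paged_mqa_logits_torch

B, next_n, H, D = 2, 1, 16, 128  # illustrative sizes only
num_blocks, block_size = 8, 1    # block size 1, as the ROCm sparse backend uses
max_model_len = num_blocks * block_size

q = torch.randn(B, next_n, H, D, device="cuda").to(torch.float8_e4m3fn)
weights = torch.rand(B * next_n, H, device="cuda", dtype=torch.float32)

# Pack [num_blocks, block_size, 1, D + 4] uint8: FP8 payload + float32 scale.
kv = torch.randn(num_blocks, block_size, 1, D, device="cuda")
kv_fp8_bytes = kv.to(torch.float8_e4m3fn).view(torch.uint8)
scale_bytes = torch.ones(num_blocks, block_size, 1, 1, device="cuda").view(torch.uint8)
kv_cache = torch.cat([kv_fp8_bytes, scale_bytes], dim=-1)

context_lens = torch.full((B,), max_model_len, dtype=torch.int32, device="cuda")
block_tables = torch.arange(num_blocks, dtype=torch.int32, device="cuda").repeat(B, 1)

logits = fp8_paged_mqa_logits_torch(
    q, kv_cache, weights, context_lens, block_tables, max_model_len
)
assert logits.shape == (B * next_n, max_model_len)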