[ROCm] Add AMD GPU support on Deepseek v3.2 and SparseMLA (#26670)
Signed-off-by: ganyi <ygan@amd.com>
parent 6eb745d9bd
commit 06c20c9904
@@ -552,7 +552,11 @@ __global__ void indexer_k_quant_and_cache_kernel(
+#ifndef USE_ROCM
   __syncwarp();
+#endif
+#if defined(__gfx942__)
+  float scale = fmaxf(amax, 1e-4) / 224.0f;
+#else
   float scale = fmaxf(amax, 1e-4) / 448.0f;
+#endif
   if (use_ue8m0) {
     scale = exp2f(ceilf(log2f(scale)));
   }
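Note: the per-group FP8 scale logic above, sketched in Python for clarity. This is a minimal illustration of what the kernel computes, not the kernel itself; 448 is the e4m3fn maximum, 224 is the divisor this commit uses on gfx942-class (FNUZ) hardware, and the ue8m0 path rounds the scale up to a power of two.

import math

def indexer_k_scale(amax: float, is_gfx942: bool, use_ue8m0: bool) -> float:
    # Mirror of the kernel's scale selection: clamp amax away from zero, then
    # divide by the platform-specific FP8 divisor.
    scale = max(amax, 1e-4) / (224.0 if is_gfx942 else 448.0)
    if use_ue8m0:
        # UE8M0 scales are pure powers of two; round up so values still fit.
        scale = 2.0 ** math.ceil(math.log2(scale))
    return scale

print(indexer_k_scale(1.0, is_gfx942=False, use_ue8m0=True))  # 0.00390625 (= 2**-8)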
vllm/attention/ops/rocm_aiter_mla_sparse.py (new file, 210 lines)
@@ -0,0 +1,210 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from functools import lru_cache

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
def fp8_mqa_logits_torch(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    kv, scale = kv
    seq_len_kv = kv.shape[0]
    k = kv.to(torch.bfloat16)
    q = q.to(torch.bfloat16)

    mask_lo = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
    )
    mask_hi = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
    )
    mask = mask_lo & mask_hi

    score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits


def rocm_fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """

    # TODO(ganyi): Temporary workaround; remove the module check and the
    # reference path after aiter merges this kernel into main.
    @lru_cache
    def has_mqa_logits_module():
        return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None

    if rocm_aiter_ops.is_enabled() and has_mqa_logits_module():
        from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits

        kv, scale = kv
        return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
    else:
        return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
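Note: a minimal shape-check sketch for the API above (illustrative sizes; it assumes a CUDA/ROCm device, since the torch reference builds its masks with device="cuda", and the call falls back to fp8_mqa_logits_torch whenever the aiter Triton kernel is unavailable).

import torch

from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits

M, H, D, N = 4, 8, 64, 16  # illustrative sizes
q = torch.randn(M, H, D, device="cuda").to(torch.float8_e4m3fn)
k_fp8 = torch.randn(N, D, device="cuda").to(torch.float8_e4m3fn)
k_scales = torch.ones(N, device="cuda", dtype=torch.float32)
weights = torch.rand(M, H, device="cuda", dtype=torch.float32)
ks = torch.zeros(M, dtype=torch.int32, device="cuda")       # each query sees K starting at 0 ...
ke = torch.full((M,), N, dtype=torch.int32, device="cuda")  # ... up to (but excluding) N

logits = rocm_fp8_mqa_logits(q, (k_fp8, k_scales), weights, ks, ke)
assert logits.shape == (M, N) and logits.dtype == torch.float32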
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
def fp8_paged_mqa_logits_torch(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
):
    from vllm.utils.math_utils import cdiv

    fp8_dtype = current_platform.fp8_dtype()
    batch_size, next_n, _, dim = q.size()
    kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
    scale = scale.contiguous().view(torch.float)
    q = q.float()
    kv_cache = kv_cache.view(fp8_dtype).float() * scale
    num_block, block_size, _, dim = kv_cache.size()
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=q.device,
        dtype=torch.float32,
    )
    context_lens = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens[i]
        q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
        weight_slice = (
            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
        )
        for block_rk in range(cdiv(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(
                block_rk * block_size, (block_rk + 1) * block_size, device="cuda"
            )
            mask = (k_offsets[None, :] < context_len) & (
                k_offsets[None, :] <= q_offsets[:, None]
            )
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype
                ),
                float("-inf"),
            )
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n : (i + 1) * next_n,
                block_rk * block_size : (block_rk + 1) * block_size,
            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
    return logits


def rocm_fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits using a paged KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block, pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context length
            for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
            block indices to physical blocks in the paged cache.
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            used to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the logits output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """

    if rocm_aiter_ops.is_enabled():
        from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1

        batch_size, next_n, heads, _ = q_fp8.shape
        out_qk = torch.full(
            (heads, batch_size * next_n, max_model_len),
            float("-inf"),
            device="cuda",
            dtype=torch.float32,
        )
        deepgemm_fp8_paged_mqa_logits_stage1(
            q_fp8,
            kv_cache_fp8,
            weights,
            out_qk,
            context_lens,
            block_tables,
            max_model_len,
        )
        return out_qk.sum(dim=0)
    else:
        return fp8_paged_mqa_logits_torch(
            q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len
        )
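Note: a hedged sketch of the packed FP8+scale page layout the docstring above describes (illustrative sizes; this is not the code path vLLM uses to write the cache, and the 448 divisor follows the non-FNUZ case of the kernel hunk earlier in this diff).

import torch

num_blocks, block_size, D = 2, 4, 64  # illustrative sizes
kv = torch.randn(num_blocks, block_size, 1, D)
amax = kv.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
scale = amax / 448.0                  # 224.0 on gfx942-class (FP8-FNUZ) hardware
kv_fp8 = (kv / scale).to(torch.float8_e4m3fn)

# Each (block, pos) row stores D FP8 bytes followed by the 4 bytes of its float32 scale.
packed = torch.empty(num_blocks, block_size, 1, D + 4, dtype=torch.uint8)
packed[..., :D] = kv_fp8.view(torch.uint8)
packed[..., D:] = scale.view(torch.uint8)
assert packed.shape == (num_blocks, block_size, 1, D + 4)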
@@ -594,6 +594,7 @@ def sparse_attn_indexer(
 ) -> torch.Tensor:
     # careful! this will be None in dummy run
     attn_metadata = get_forward_context().attn_metadata
+    fp8_dtype = current_platform.fp8_dtype()
     # assert isinstance(attn_metadata, dict)
     if not isinstance(attn_metadata, dict):
         return sparse_attn_indexer_fake(
@@ -633,7 +634,7 @@ def sparse_attn_indexer(
             k_fp8 = torch.empty(
                 [chunk.total_seq_lens, head_dim],
                 device=k.device,
-                dtype=torch.float8_e4m3fn,
+                dtype=fp8_dtype,
             )
             k_scale = torch.empty(
                 [chunk.total_seq_lens, 4],
@@ -647,7 +648,12 @@ def sparse_attn_indexer(
                 chunk.block_table,
                 chunk.cu_seq_lens,
             )
-            logits = fp8_mqa_logits(
+            fp8_mqa_logits_func = fp8_mqa_logits
+            if current_platform.is_rocm():
+                from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits
+
+                fp8_mqa_logits_func = rocm_fp8_mqa_logits
+            logits = fp8_mqa_logits_func(
                 q_fp8[chunk.token_start : chunk.token_end],
                 (k_fp8, k_scale.view(torch.float32)),
                 weights[chunk.token_start : chunk.token_end],
@@ -692,7 +698,14 @@ def sparse_attn_indexer(
             next_n = padded_q_fp8_decode_tokens.shape[1]
             assert batch_size == decode_metadata.seq_lens.shape[0]
             num_padded_tokens = batch_size * next_n
-            logits = fp8_paged_mqa_logits(
+            fp8_paged_mqa_logits_func = fp8_paged_mqa_logits
+            if current_platform.is_rocm():
+                from vllm.attention.ops.rocm_aiter_mla_sparse import (
+                    rocm_fp8_paged_mqa_logits,
+                )
+
+                fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits
+            logits = fp8_paged_mqa_logits_func(
                 padded_q_fp8_decode_tokens,
                 kv_cache,
                 weights[:num_padded_tokens],
@@ -749,7 +762,8 @@ def sparse_attn_indexer_fake(
     _flattened_kv = torch.empty(
         [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
     )
-    _k_fp8 = _flattened_kv[..., :head_dim].view(torch.float8_e4m3fn).contiguous()
+    fp8_dtype = current_platform.fp8_dtype()
+    _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous()
     _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
     return topk_indices_buffer
@@ -225,7 +225,18 @@ class RocmPlatform(Platform):
         from vllm.attention.backends.registry import AttentionBackendEnum

         if use_sparse:
-            raise NotImplementedError("Sparse Attention is not supported on ROCm.")
+            if kv_cache_dtype.startswith("fp8"):
+                raise ValueError(
+                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
+                )
+            assert block_size == 1, (
+                "Sparse MLA backend on ROCm only supports block size 1 for now."
+            )
+            logger.info_once("Using Sparse MLA backend on V1 engine.")
+            return (
+                "vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse."
+                "ROCMAiterMLASparseBackend"
+            )

         if use_mla:
             if selected_backend is None:
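Note: in practice the checks above mean the ROCm sparse MLA path only comes up with a block size of 1 and a non-FP8 KV cache. A hedged configuration sketch consistent with those checks (the model id is illustrative; block_size and kv_cache_dtype are standard vLLM engine arguments, but treat this as a sketch rather than a tested invocation).

from vllm import LLM

# Illustrative only: sparse MLA on ROCm requires block size 1 and a non-fp8 KV cache.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3.2-Exp",  # hypothetical model id for illustration
    block_size=1,
    kv_cache_dtype="auto",
)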
@@ -325,6 +325,7 @@ DEFAULT_BLOCK_SIZE = [128, 128]
 def per_block_cast_to_fp8(
     x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    fp8_dtype = current_platform.fp8_dtype()
     assert x.dim() == 2
     m, n = x.shape
     block_m, block_n = block_size
@@ -334,9 +335,9 @@ def per_block_cast_to_fp8(
     x_padded[:m, :n] = x
     x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
     x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
-    sf = x_amax / 448.0
+    sf = x_amax / 224.0 if current_platform.is_fp8_fnuz() else x_amax / 448.0
     sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
-    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
+    x_scaled = (x_view * (1.0 / sf)).to(fp8_dtype)
     return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
         x_view.size(0), x_view.size(2)
     )
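Note: a small shape illustration of the per-block scaling above (assumed, already-padded sizes; one scale is produced per 128x128 block).

import torch

m, n, block_m, block_n = 256, 512, 128, 128   # illustrative, already padded to block multiples
x_padded = torch.randn(m, n)
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
sf = x_amax / 448.0                           # divided by 224.0 instead when current_platform.is_fp8_fnuz()
print(x_view.shape, x_amax.shape)             # torch.Size([2, 128, 4, 128]) torch.Size([2, 1, 4, 1])
# i.e. 2 x 4 blocks of 128 x 128, each with its own scale factor sf.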
@@ -168,7 +168,7 @@ def _convert_req_index_to_global_index_kernel(
     inblock_off = tok % BLOCK_SIZE

     # Guard block_table access
-    valid_block = block_id < max_num_blocks_per_req
+    valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
     bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
     base = tl.load(bt_ptr, mask=valid_block, other=0)
@@ -11,7 +11,8 @@ from vllm.attention.backends.abstract import (
 )
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata
+from vllm.platforms import current_platform
+from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,
     AttentionMetadataBuilder,
@@ -23,7 +24,9 @@ logger = init_logger(__name__)


 class DeepseekV32IndexerBackend(AttentionBackend):
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
+    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [
+        1 if current_platform.is_rocm() else 64
+    ]

     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
@@ -328,10 +331,10 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
         requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()

         seq_lens = common_attn_metadata.seq_lens[:num_decodes]
-
-        self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
-            seq_lens, self.kv_cache_spec.block_size, self.num_sms
-        )
+        if is_deep_gemm_supported():
+            self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
+                seq_lens, self.kv_cache_spec.block_size, self.num_sms
+            )
         decode_metadata = DeepSeekV32IndexerDecodeMetadata(
             block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
             seq_lens=common_attn_metadata.seq_lens[:num_decodes],
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py (new file, 325 lines)
@@ -0,0 +1,325 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional

import numpy as np
import torch

from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import (
    AttentionBackend,
    AttentionLayer,
    AttentionMetadata,
)
from vllm.attention.backends.utils import get_mla_dims
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (
    MLACommonBaseImpl,
)
from vllm.v1.attention.backends.mla.flashmla_sparse import (
    triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec

if TYPE_CHECKING:
    from vllm.model_executor.models.deepseek_v2 import Indexer
logger = init_logger(__name__)


class ROCMAiterMLASparseBackend(AttentionBackend):
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "ROCM_AITER_MLA_SPARSE"

    @staticmethod
    def get_metadata_cls() -> type[AttentionMetadata]:
        return ROCMAiterMLASparseMetadata

    @staticmethod
    def get_builder_cls() -> type["ROCMAiterMLASparseMetadataBuilder"]:
        return ROCMAiterMLASparseMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["ROCMAiterMLASparseImpl"]:
        return ROCMAiterMLASparseImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        return (num_blocks, block_size, head_size)

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        return [576]
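Note: a quick sanity check of the cache layout this backend declares (illustrative numbers; with the block size of 1 required on ROCm, each block holds a single 576-wide latent+rope entry).

from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse import ROCMAiterMLASparseBackend

shape = ROCMAiterMLASparseBackend.get_kv_cache_shape(
    num_blocks=1024, block_size=1, num_kv_heads=1, head_size=576
)
assert shape == (1024, 1, 576)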
@dataclass
class ROCMAiterMLASparseMetadata:
    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    block_table: torch.Tensor
    req_id_per_token: torch.Tensor
    block_size: int = 1
    topk_tokens: int = 2048


@dataclass
class ROCMAiterMLASparseMetadataBuilder(
    AttentionMetadataBuilder[ROCMAiterMLASparseMetadata]
):
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.NEVER

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.device = device

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
        self.topk_tokens_tensor = torch.tensor(
            [self.topk_tokens], device=device, dtype=torch.int32
        )
        self.max_model_len_tensor = torch.tensor(
            [self.model_config.max_model_len], device=device, dtype=torch.int32
        )
        # this is ignored by `flash_mla_with_kvcache` if indices is not None
        self.dummy_block_table = torch.empty(
            (1, 1), dtype=torch.int32, device=self.device
        )

        self.req_id_per_token_buffer = torch.empty(
            (vllm_config.scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=device,
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> ROCMAiterMLASparseMetadata:
        num_tokens = common_attn_metadata.num_actual_tokens
        starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
        seg_lengths = np.diff(starts)
        req_id_per_token = np.repeat(
            np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
        )
        # Zero-fill for cudagraphs
        self.req_id_per_token_buffer.fill_(0)
        self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
            torch.from_numpy(req_id_per_token), non_blocking=True
        )
        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

        metadata = ROCMAiterMLASparseMetadata(
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            block_table=common_attn_metadata.block_table_tensor,
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
        )
        return metadata
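Note: a worked example of the req_id_per_token mapping built above (illustrative values). The cumulative query-start offsets are turned into per-token request ids by repeating each request index by its query length.

import numpy as np

query_start_loc = np.array([0, 3, 5, 9], dtype=np.int32)  # 3 requests with 3, 2 and 4 tokens
seg_lengths = np.diff(query_start_loc)                     # [3 2 4]
req_id_per_token = np.repeat(np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths)
print(req_id_per_token)                                    # [0 0 0 1 1 2 2 2 2]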
# Taken from
# https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla_prefill.py#L72
def reference_mla_sparse_prefill(
    q: torch.Tensor, kv: torch.Tensor, indices: torch.Tensor, sm_scale: float, d_v: int
) -> tuple[torch.Tensor, torch.Tensor]:
    import math

    def log2sumexp2(a: torch.Tensor, dim: int) -> torch.Tensor:
        return torch.logsumexp(a * math.log(2), dim=dim) * math.log2(math.e)

    skv = kv.shape[0]
    sq = q.shape[0]
    topk = indices.shape[-1]
    dqk = q.shape[-1]
    indices = indices[:, 0, :]  # [s_q, topk]
    invalid_indices_mask = (indices < 0) | (indices >= skv)
    indices[invalid_indices_mask] = 0
    qs = q  # [s_q, h_q, d_qk]
    kvs = kv[:, 0, :][indices].view(sq, topk, dqk)  # [s_q, topk, d_qk]

    attn_score = (qs @ kvs.transpose(1, 2)).float()  # [s_q, h_q, topk]
    attn_score.masked_fill_(invalid_indices_mask.unsqueeze(1), float("-inf"))
    attn_score *= sm_scale * math.log2(math.e)
    lse = log2sumexp2(attn_score, dim=-1)  # [s_q, h_q]
    attn_score = torch.exp2(attn_score - lse.unsqueeze(-1))  # [s_q, h_q, topk]
    result = attn_score.to(q.dtype) @ kvs[:, :, :d_v]
    return (result, lse)
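Note: a hedged shape sketch for the reference above (sizes are illustrative; 576 = 512 latent + 64 rope channels, and only the first d_v = 512 channels contribute to the output).

import torch

from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse import reference_mla_sparse_prefill

s_q, h_q, s_kv, d_qk, d_v, topk = 2, 4, 16, 576, 512, 8
q = torch.randn(s_q, h_q, d_qk, dtype=torch.bfloat16)
kv = torch.randn(s_kv, 1, d_qk, dtype=torch.bfloat16)
indices = torch.randint(0, s_kv, (s_q, 1, topk))  # out-of-range entries would be masked to -inf

out, lse = reference_mla_sparse_prefill(q, kv, indices, sm_scale=d_qk**-0.5, d_v=d_v)
assert out.shape == (s_q, h_q, d_v) and lse.shape == (s_q, h_q)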
class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        topk_indice_buffer: torch.Tensor | None = None,
        indexer: Optional["Indexer"] = None,
        **mla_args,
    ) -> None:
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )
        self.softmax_scale = scale
        assert indexer is not None
        self.topk_indices_buffer = indexer.topk_indices_buffer
        self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()

    def _forward_bf16_kv(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: ROCMAiterMLASparseMetadata,
    ) -> torch.Tensor:
        num_tokens = q.shape[0]
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
            -1, 1, kv_c_and_k_pe_cache.shape[-1]
        )

        topk_indices = topk_indices.view(num_tokens, 1, -1)
        output = reference_mla_sparse_prefill(
            q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale, 512
        )[0]
        return output[:, : self.num_heads, :]

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: ROCMAiterMLASparseMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
        # MQA 576/512 approach for both prefill and decode

        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for ROCMAiterMLASparse"
            )

        if attn_metadata is None:
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs.
            return output.fill_(0)

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs

        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        # Convert from (B, N, P) to (N, B, P)
        q_nope = q_nope.transpose(0, 1)
        if self.is_fp8bmm_enabled:
            # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
            ql_nope = rocm_aiter_ops.triton_fp8_bmm(
                q_nope, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
            )
        else:
            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
            ql_nope = torch.bmm(q_nope, self.W_UK_T)
            # Convert from (N, B, L) to (B, N, L)
            ql_nope = ql_nope.transpose(0, 1)

        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        topk_indices_global = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
        )

        q = torch.cat([ql_nope, q_pe], dim=-1)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        attn_out = self._forward_bf16_kv(
            q, kv_cache, topk_indices_global, attn_metadata
        )

        self._v_up_proj(attn_out, out=output[:num_actual_toks])
        return output
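Note: the dimensions implied by the forward pass above, spelled out briefly (a note on the MQA 576/512 layout referenced in the code comment).

# head_size 576 = 512 latent ("nope") channels + 64 rotary ("rope") channels.
d_latent, d_rope = 512, 64
head_size = d_latent + d_rope  # 576, the only head size get_supported_head_sizes() reports
# q is assembled as cat([ql_nope, q_pe], dim=-1) -> [..., 576] and attends over the
# 576-wide latent+rope cache; reference_mla_sparse_prefill is called with d_v=512, so
# only the first 512 output channels feed the V up-projection (_v_up_proj).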
@@ -316,7 +316,7 @@ def bind_kv_cache(
         # TODO - analyze where runner_kv_caches is used and the right
         # way to ensure it properly reflects multiple attention layers
         # in the same decoder block.
-        if current_platform.is_cuda() or current_platform.is_xpu():
+        if current_platform.is_cuda_alike() or current_platform.is_xpu():
             # We know that the GPU runner is not impacted by this
             # case. Some test code depends on runner_kv_caches, but
             # not in a way that's impacted by ignoring this.