mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 13:35:54 +08:00
[Attention] Blackwell FP8 MLA support with CUTLASS_MLA backend (#23289)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent
731a6940e3
commit
a742322092
@ -64,11 +64,11 @@ struct IsPersistent {
|
||||
static const bool value = v;
|
||||
};
|
||||
|
||||
template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
|
||||
template <typename T, typename TOut, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
|
||||
struct MlaSm100 {
|
||||
using Element = T;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = T;
|
||||
using ElementOut = TOut;
|
||||
|
||||
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
|
||||
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
||||
@ -178,7 +178,7 @@ typename T::Fmha::Arguments args_from_options(
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename Element, bool IsPaged128, typename PersistenceOption>
|
||||
template <typename Element, typename ElementOut, bool IsPaged128, typename PersistenceOption>
|
||||
void runMla(
|
||||
at::Tensor const& out,
|
||||
at::Tensor const& q_nope,
|
||||
@ -190,7 +190,7 @@ void runMla(
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits,
|
||||
cudaStream_t stream) {
|
||||
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
|
||||
using MlaSm100Type = MlaSm100<Element, ElementOut, IsPaged128, PersistenceOption>;
|
||||
typename MlaSm100Type::Fmha fmha;
|
||||
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
|
||||
|
||||
@ -233,13 +233,13 @@ void sm100_cutlass_mla_decode(
|
||||
DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
|
||||
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
runMla<cutlass::half_t, cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
runMla<cutlass::bfloat16_t, cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
||||
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
runMla<cutlass::float_e4m3_t, cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
||||
@ -253,7 +253,7 @@ void sm100_cutlass_mla_decode(
|
||||
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
||||
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
|
||||
// which are float, so Element type here doesn't matter.
|
||||
using MlaSm100Type = MlaSm100<cutlass::half_t, true>;
|
||||
using MlaSm100Type = MlaSm100<cutlass::half_t, cutlass::half_t, true>;
|
||||
|
||||
// Get split kv. Requires problem shape and sm_count only.
|
||||
typename MlaSm100Type::Fmha::Arguments arguments;
|
||||
|
||||
@ -1,96 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
reason="Cutlass MLA Requires compute capability of 10 or above.",
|
||||
allow_module_level=True)
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
def ref_mla(
|
||||
out: Tensor, # (bs, num_heads, v_head_dim)
|
||||
query: Tensor, # (bs, num_heads, head_dim)
|
||||
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
|
||||
scale: float,
|
||||
block_tables: Tensor, # (bs, max_num_blocks)
|
||||
seq_lens: Tensor, # (bs,)
|
||||
):
|
||||
bs, num_heads, v_head_dim = out.shape
|
||||
head_dim = query.shape[2]
|
||||
|
||||
for i in range(bs):
|
||||
# gather and flatten KV-cache
|
||||
kv = kv_cache[
|
||||
block_tables[i]] # (max_num_blocks, block_size, head_dim)
|
||||
kv = kv.view(1, -1,
|
||||
head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim)
|
||||
v = kv[:, :, :v_head_dim]
|
||||
|
||||
q = query[i].view(num_heads, 1, head_dim)
|
||||
o = F.scaled_dot_product_attention(q,
|
||||
kv,
|
||||
v,
|
||||
scale=scale,
|
||||
enable_gqa=True)
|
||||
out[i] = o.view(num_heads, v_head_dim)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096])
|
||||
@pytest.mark.parametrize("bs", [1, 2, 4])
|
||||
@pytest.mark.parametrize("varlen", [False, True])
|
||||
@pytest.mark.parametrize("block_size", [16, 64, 128])
|
||||
def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
|
||||
varlen: bool, block_size: int):
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.set_default_device('cuda')
|
||||
torch.manual_seed(42)
|
||||
|
||||
d = 576
|
||||
h_q = 128
|
||||
dv = 512
|
||||
|
||||
q_nope_dim = 128
|
||||
q_pe_dim = 64
|
||||
scale = (q_nope_dim + q_pe_dim)**(-0.5)
|
||||
if varlen:
|
||||
seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
|
||||
seq_lens = seq_lens.clip(2).to(torch.int32)
|
||||
def cal_diff(x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
name: str,
|
||||
use_fp8: bool = False) -> None:
|
||||
x, y = x.double(), y.double()
|
||||
cos_diff = 1 - 2 * (x * y).sum().item() / max(
|
||||
(x * x + y * y).sum().item(), 1e-12)
|
||||
if (use_fp8):
|
||||
assert cos_diff < 1e-4
|
||||
else:
|
||||
seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
|
||||
max_seq_len = seq_lens.max().item()
|
||||
block_num = (max_seq_len + block_size - 1) // block_size
|
||||
assert cos_diff < 1e-5
|
||||
|
||||
# Pad block_num so that small blocks can be packed into full 128-sized
|
||||
# CUTLASS tiles. One 128-wide tile can hold (128 // block_size) small
|
||||
# blocks.
|
||||
pack_factor = 128 // block_size
|
||||
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
|
||||
|
||||
# Amplify input values to ensure test coverage of edge cases where CUTLASS
|
||||
# kernel errors occur with split_k settings.
|
||||
q = torch.randn(bs, h_q, d) * 100
|
||||
block_table = torch.randint(0,
|
||||
bs * block_num, (bs, block_num),
|
||||
dtype=torch.int32)
|
||||
CUTLASS_MLA_UNSUPPORTED_REASON = \
|
||||
"Cutlass MLA Requires compute capability of 10 or above." \
|
||||
if not current_platform.is_device_capability(100) \
|
||||
else "Cutlass MLA is supported"
|
||||
|
||||
kv_cache = torch.randn(block_table.numel(), block_size, d)
|
||||
|
||||
out_ref = q.new_zeros(bs, h_q, dv)
|
||||
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
|
||||
out_ans = torch.zeros_like(out_ref)
|
||||
q_nope = q[:, :, :dv].clone()
|
||||
q_pe = q[:, :, dv:].clone()
|
||||
ops.cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache, seq_lens,
|
||||
block_table, scale)
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(100),
|
||||
reason=CUTLASS_MLA_UNSUPPORTED_REASON)
|
||||
@pytest.mark.parametrize("b", [128])
|
||||
@pytest.mark.parametrize("s_q", [1])
|
||||
@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
|
||||
@pytest.mark.parametrize("h_q", [16, 32, 64, 128])
|
||||
@pytest.mark.parametrize("h_kv", [1])
|
||||
@pytest.mark.parametrize("d", [576])
|
||||
@pytest.mark.parametrize("dv", [512])
|
||||
@pytest.mark.parametrize("block_size", [64])
|
||||
@pytest.mark.parametrize("causal", [True])
|
||||
@pytest.mark.parametrize("varlen", [False, True])
|
||||
@pytest.mark.parametrize("torch_dtype", [torch.bfloat16, torch.float8_e4m3fn])
|
||||
@torch.inference_mode()
|
||||
def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
|
||||
causal, varlen, torch_dtype):
|
||||
device = torch.device("cuda:0")
|
||||
if torch_dtype == torch.float8_e4m3fn:
|
||||
init_dtype = torch.bfloat16
|
||||
else:
|
||||
init_dtype = torch_dtype
|
||||
torch.set_default_dtype(init_dtype)
|
||||
torch.set_default_device(device)
|
||||
torch.cuda.set_device(device)
|
||||
torch.manual_seed(42)
|
||||
random.seed(42)
|
||||
|
||||
torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2)
|
||||
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
|
||||
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")
|
||||
|
||||
use_fp8 = torch_dtype == torch.float8_e4m3fn
|
||||
scale = math.sqrt(d)**(-1)
|
||||
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
|
||||
if varlen:
|
||||
for i in range(b):
|
||||
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2),
|
||||
s_q)
|
||||
total_seqlens = cache_seqlens.sum().item()
|
||||
max_seqlen = cache_seqlens.max().item()
|
||||
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
|
||||
|
||||
q = torch.randn(b, s_q, h_q, d)
|
||||
block_table = torch.arange(b * max_seqlen_pad // block_size,
|
||||
dtype=torch.int32).view(
|
||||
b, max_seqlen_pad // block_size)
|
||||
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
|
||||
blocked_v = blocked_k[..., :dv]
|
||||
|
||||
init_dtype = q.dtype
|
||||
if use_fp8:
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
descale_q = torch.ones((1), dtype=torch.float32)
|
||||
descale_k = torch.ones((1), dtype=torch.float32)
|
||||
|
||||
q = q.to(fp8_dtype)
|
||||
blocked_k = blocked_k.to(fp8_dtype)
|
||||
blocked_v = blocked_v.to(fp8_dtype)
|
||||
else:
|
||||
descale_q = None
|
||||
descale_k = None
|
||||
|
||||
def cutlass_mla():
|
||||
MAX_HEADS = 128
|
||||
|
||||
q_reshaped = q.squeeze(1)
|
||||
q_nope = q_reshaped[:, :, :dv].clone()
|
||||
q_pe = q_reshaped[:, :, dv:].clone()
|
||||
|
||||
if h_q < MAX_HEADS:
|
||||
q_nope_padded = q_nope.new_empty((b, MAX_HEADS, dv))
|
||||
q_nope_padded[:, :h_q] = q_nope
|
||||
q_nope = q_nope_padded
|
||||
|
||||
q_pe_padded = q_pe.new_empty((b, MAX_HEADS, d - dv))
|
||||
q_pe_padded[:, :h_q] = q_pe
|
||||
q_pe = q_pe_padded
|
||||
|
||||
kv_cache_flat = blocked_k.squeeze(2)
|
||||
device_properties = torch.cuda.get_device_properties(
|
||||
torch.device("cuda:0"))
|
||||
sm_count = device_properties.multi_processor_count
|
||||
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
|
||||
max_seqlen * block_size, b, sm_count, num_kv_splits=1)
|
||||
workspace = torch.empty(workspace_size,
|
||||
device="cuda",
|
||||
dtype=torch.uint8)
|
||||
|
||||
out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype)
|
||||
|
||||
ops.sm100_cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache_flat,
|
||||
cache_seqlens, block_table, workspace,
|
||||
scale, 1)
|
||||
return out_ans[:, :h_q].contiguous()
|
||||
|
||||
def scaled_dot_product_attention(query, key, value, is_causal=False):
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
value = value.float()
|
||||
key = key.repeat_interleave(h_q // h_kv, dim=0)
|
||||
value = value.repeat_interleave(h_q // h_kv, dim=0)
|
||||
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
|
||||
if is_causal:
|
||||
s_q = query.shape[-2]
|
||||
s_k = key.shape[-2]
|
||||
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
|
||||
temp_mask = torch.ones(s_q, s_k,
|
||||
dtype=torch.bool).tril(diagonal=s_k - s_q)
|
||||
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
||||
attn_bias.to(query.dtype)
|
||||
attn_weight += attn_bias
|
||||
lse = attn_weight.logsumexp(dim=-1)
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
|
||||
return attn_weight @ value, lse
|
||||
|
||||
def ref_mla():
|
||||
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
|
||||
blocked_k_ = (blocked_k.to(torch.float) *
|
||||
descale_k).to(init_dtype) if use_fp8 else blocked_k
|
||||
blocked_v_ = (blocked_v.to(torch.float) *
|
||||
descale_k).to(init_dtype) if use_fp8 else blocked_v
|
||||
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
|
||||
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
|
||||
for i in range(b):
|
||||
begin = i * max_seqlen_pad
|
||||
end = begin + cache_seqlens[i]
|
||||
out_i, lse_i = scaled_dot_product_attention(
|
||||
q_[i].transpose(0, 1),
|
||||
blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
|
||||
blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
|
||||
is_causal=causal,
|
||||
)
|
||||
out[i] = out_i.transpose(0, 1)
|
||||
lse[i] = lse_i
|
||||
return out, lse
|
||||
|
||||
out_cutlass = cutlass_mla()
|
||||
out_torch, lse_torch = ref_mla()
|
||||
# Extract the single token (s_q=1) slice to match cutlass output shape
|
||||
out_torch_slice = out_torch[:, 0, :, :] # [b, h_q, dv]
|
||||
cal_diff(out_cutlass, out_torch_slice, "out", use_fp8)
|
||||
|
||||
t = triton.testing.do_bench(cutlass_mla)
|
||||
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
|
||||
bytes = (total_seqlens * h_kv * d +
|
||||
b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
|
||||
b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
|
||||
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
|
||||
f"{bytes / 10 ** 6 / t:.0f} GB/s")
|
||||
|
||||
@ -500,8 +500,8 @@ class CudaPlatformBase(Platform):
|
||||
else:
|
||||
attention_backend = "FLASHMLA"
|
||||
|
||||
# Only FlashMLA supports fp8
|
||||
if attention_backend == "FLASHMLA":
|
||||
# Only FlashMLA and CUTLASS_MLA support fp8
|
||||
if attention_backend in ["FLASHMLA", "CUTLASS_MLA"]:
|
||||
supported = True
|
||||
else:
|
||||
supported = (not fp8_attention)
|
||||
|
||||
@ -108,10 +108,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
"are not implemented for "
|
||||
"CutlassMLAImpl")
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"CutlassMLA V1 with FP8 KV cache not yet supported")
|
||||
|
||||
self._use_old_cutlass_mla = False
|
||||
force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None)
|
||||
if force_old_cutlass:
|
||||
@ -182,11 +178,10 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
> 0), f"block num must be greater than 0, got {block_num}"
|
||||
assert block_num % (128 / PAGE_SIZE) == 0
|
||||
|
||||
# TODO(kaixih@nvidia): support fp8
|
||||
assert q_nope.dtype in (
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
|
||||
torch.float16, torch.bfloat16, torch.float8_e4m3fn), (
|
||||
f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got "
|
||||
f"{q_nope.dtype}.")
|
||||
assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
|
||||
assert (
|
||||
seq_lens.dtype == torch.int32
|
||||
@ -195,7 +190,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
page_table.dtype == torch.int32
|
||||
), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
|
||||
|
||||
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))
|
||||
dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype)
|
||||
else q_nope.dtype)
|
||||
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
|
||||
|
||||
ops.sm100_cutlass_mla_decode(
|
||||
out,
|
||||
@ -220,9 +217,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Cutlass MLA not yet supported")
|
||||
|
||||
# Adjust workspace size (if necessary)
|
||||
self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
|
||||
|
||||
@ -252,8 +246,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Cutlass MLA not yet supported")
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"FP8 Cutlass MLA not supported with FORCE_OLD_CUTLASS_MLA")
|
||||
|
||||
B = q_nope.shape[0]
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user