[Kernel] Unified Triton kernel that doesn't distinguish between prefill + decode (#16828)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Thomas Parnell 2025-05-06 18:21:48 -04:00 committed by GitHub
parent de906b95f9
commit 2f925e5777
3 changed files with 566 additions and 27 deletions
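
For orientation, here is a minimal sketch (illustrative values, not taken from the diff) of what "unified" means: decode requests are simply requests with query_len == 1 and prefill / chunked-prefill requests have query_len > 1, and both are described to the kernel through the same cu_seqlens_q / seqused_k tensors, so a mixed batch needs only one launch.

import torch

# Hypothetical mixed batch: two decode steps plus one chunked-prefill request.
query_lens = [1, 1, 129]     # new query tokens per request
kv_lens = [1328, 523, 463]   # total KV length per request (context + new tokens)

# Same construction as cu_query_lens / kv_lens in the test below.
cu_seqlens_q = torch.tensor([0] + query_lens,
                            dtype=torch.int32).cumsum(dim=0, dtype=torch.int32)
seqused_k = torch.tensor(kv_lens, dtype=torch.int32)

print(cu_seqlens_q.tolist())  # [0, 1, 2, 131]
print(seqused_k.tolist())     # [1328, 523, 463]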


@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import pytest
import torch
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.platforms import current_platform
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
def ref_paged_attn(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
query_lens: list[int],
kv_lens: list[int],
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
soft_cap: Optional[float] = None,
) -> torch.Tensor:
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape
outputs: list[torch.Tensor] = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
q = query[start_idx:start_idx + query_len]
q *= scale
num_kv_blocks = (kv_len + block_size - 1) // block_size
block_indices = block_tables[i, :num_kv_blocks]
k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
k = k[:kv_len]
v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
v = v[:kv_len]
if q.shape[1] != k.shape[1]:
k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
attn = torch.einsum("qhd,khd->hqk", q, k).float()
empty_mask = torch.ones(query_len, kv_len)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
if sliding_window is not None:
sliding_window_mask = torch.triu(empty_mask,
diagonal=kv_len -
(query_len + sliding_window) +
1).bool().logical_not()
mask |= sliding_window_mask
if soft_cap is not None and soft_cap > 0:
attn = soft_cap * torch.tanh(attn / soft_cap)
attn.masked_fill_(mask, float("-inf"))
attn = torch.softmax(attn, dim=-1).to(v.dtype)
out = torch.einsum("hqk,khd->qhd", attn, v)
outputs.append(out)
start_idx += query_len
return torch.cat(outputs, dim=0)
@pytest.mark.parametrize("seq_lens",
[[(1, 1328), (5, 18),
(129, 463)], [(1, 523), (1, 37), (1, 2011)]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("q_dtype", QDTYPES)
@torch.inference_mode()
def test_triton_unified_attn(
seq_lens: list[tuple[int, int]],
num_heads: tuple[int, int],
head_size: int,
sliding_window: Optional[int],
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
q_dtype: Optional[torch.dtype],
) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
(-1, -1))
scale = head_size**-0.5
query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
value_cache = torch.randn_like(key_cache)
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
output = torch.empty_like(query)
maybe_quantized_query = query
maybe_quantized_key_cache = key_cache
maybe_quantized_value_cache = value_cache
q_descale = None
k_descale = None
v_descale = None
if q_dtype is not None:
# QKV are drawn from N(0, 1): no need for an fp8 scaling factor
maybe_quantized_query = query.to(q_dtype)
maybe_quantized_key_cache = key_cache.to(q_dtype)
maybe_quantized_value_cache = value_cache.to(q_dtype)
scale_shape = (num_seqs, num_kv_heads)
q_descale = None # Not yet supported
k_descale = torch.rand(scale_shape, dtype=torch.float32)
v_descale = torch.rand(scale_shape, dtype=torch.float32)
unified_attention(
q=maybe_quantized_query,
k=maybe_quantized_key_cache,
v=maybe_quantized_value_cache,
out=output,
cu_seqlens_q=cu_query_lens,
seqused_k=kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
)
ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
)
atol, rtol = 1.5e-2, 1e-2
if q_dtype is not None:
atol, rtol = 1.5e-1, 1.5e-1
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol,
    msg=f"max diff: {torch.max(torch.abs(output - ref_output))}")


@@ -0,0 +1,333 @@
# SPDX-License-Identifier: Apache-2.0
# Authors:
# - Burkhard Ringlein <ngl@zurich.ibm.com>
# - Jan van Lunteren <jvl@zurich.ibm.com>
# - Chih-Chieh Yang <chih.chieh.yang@ibm.com>
# - Thomas Parnell <tpa@zurich.ibm.com>
import triton
import triton.language as tl
from vllm.logger import init_logger
logger = init_logger(__name__)
@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y
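# apply_softcap below computes softcap * tanh(S / softcap), with tanh written
# via exponentials: tanh(z) = (e^z - e^-z) / (e^z + e^-z).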
@triton.jit
def apply_softcap(S, x):
Sdiv = S / x
p1 = tl.exp(Sdiv)
p2 = tl.exp(-Sdiv)
return x * (p1 - p2) / (p1 + p2)
@triton.jit
def kernel_unified_attention_2d(
output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
scale, # float32
k_scale, # float32
v_scale, # float32
softcap, # float32
num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int
block_table_stride: tl.int64, # int
query_stride_0: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_0: tl.int64, # int
output_stride_1: tl.int64, # int, should be equal to head_size
BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
stride_k_cache_3: tl.int64, # int
stride_v_cache_0: tl.int64, # int
stride_v_cache_1: tl.int64, # int
stride_v_cache_2: tl.int64, # int
stride_v_cache_3: tl.int64, # int
query_start_len_ptr, # [num_seqs+1]
BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32,
):
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
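# Binary search over query_start_len_ptr to find the sequence that owns this
# q block: sequence i's first block index is
# query_start_len_ptr[i] // BLOCK_Q + i (matching the launch-grid upper bound).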
left: tl.int32 = 0
right = num_seqs
while left < right:
mid = (left + right) // 2
mid_val = tl.load(query_start_len_ptr + mid) // BLOCK_Q + mid
if mid_val <= q_block_global_idx:
left = mid + 1
else:
right = mid
seq_idx = left - 1
q_block_start_idx = tl.load(query_start_len_ptr +
seq_idx) // BLOCK_Q + seq_idx
q_block_local_idx = q_block_global_idx - q_block_start_idx
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
cur_batch_query_len = cur_batch_in_all_stop_index \
- cur_batch_in_all_start_index
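# The launch grid over-allocates q blocks (see unified_attention below);
# blocks beyond this sequence's query length simply return.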
if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return
offs_m = tl.arange(0, BLOCK_Q * num_queries_per_kv)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
query_offset_0 = cur_batch_in_all_start_index + query_pos
query_offset_1 = kv_head_idx * num_queries_per_kv + \
offs_m % num_queries_per_kv
query_offset = (query_offset_0[:, None] * query_stride_0 +
query_offset_1[:, None] * query_stride_1 + offs_d[None, :])
dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1)
query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)
# Q : (BLOCK_Q * num_queries_per_kv, HEAD_SIZE,)
Q = tl.load(
query_ptr + query_offset,
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
other=0.0,
)
block_table_offset = seq_idx * block_table_stride
M = tl.full([BLOCK_Q * num_queries_per_kv],
float("-inf"),
dtype=tl.float32)
L = tl.full([BLOCK_Q * num_queries_per_kv], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_Q * num_queries_per_kv, HEAD_SIZE_PADDED],
dtype=tl.float32)
# total sequence length (context + query tokens) for this sequence
seq_len = tl.load(seq_lens_ptr + seq_idx)
# context length for this particular sequence
context_len = seq_len - cur_batch_query_len
# alibi slope for this head
if USE_ALIBI_SLOPES:
alibi_slope = tl.load(alibi_slopes_ptr + query_offset_1,
mask=query_mask_1,
other=0.0)
num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
# iterate through tiles
for j in range(0, num_blocks):
physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
offs_n = tl.arange(0, BLOCK_SIZE)
v_offset = (physical_block_idx * stride_v_cache_0 +
kv_head_idx * stride_v_cache_2 +
offs_d[None, :] * stride_v_cache_3 +
offs_n[:, None] * stride_v_cache_1)
k_offset = (physical_block_idx * stride_k_cache_0 +
kv_head_idx * stride_k_cache_2 +
offs_d[:, None] * stride_k_cache_3 +
offs_n[None, :] * stride_k_cache_1)
# K : (HEAD_SIZE, BLOCK_SIZE)
K_load = tl.load(key_cache_ptr + k_offset,
mask=dim_mask[:, None],
other=0.0)
if K_load.dtype.is_fp8():
if Q.dtype.is_fp8():
K = K_load
else:
K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype)
else:
K = K_load
# V : (BLOCK_SIZE, HEAD_SIZE)
V_load = tl.load(value_cache_ptr + v_offset,
mask=dim_mask[None, :],
other=0.0)
if V_load.dtype.is_fp8():
if Q.dtype.is_fp8():
V = V_load
else:
V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype)
else:
V = V_load
seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
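# causal mask: KV position t is visible to query position p iff t <= context_len + p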
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
# S : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
S = tl.zeros(shape=(BLOCK_Q * num_queries_per_kv, BLOCK_SIZE),
dtype=tl.float32)
S += scale * tl.dot(Q, K)
if USE_SOFTCAP:
S = apply_softcap(S, softcap)
S = tl.where(query_mask_1[:, None] & query_mask_0[:, None] & seq_mask,
S, float("-inf"))
if SLIDING_WINDOW > 0:
S = tl.where((context_len + query_pos[:, None] - seq_offset)
< SLIDING_WINDOW, S, float("-inf"))
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
# compute running maximum
# m_j : (BLOCK_Q * num_queries_per_kv,)
m_j = tl.maximum(M, tl.max(S, axis=1))
# For sliding window there's a chance the max is -inf due to masking of
# the entire row. In this case we need to set m_j to 0 to avoid NaN
m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
# P : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
P = tl.exp(S - m_j[:, None])
# l_j : (BLOCK_Q * num_queries_per_kv,)
l_j = tl.sum(P, axis=1)
# alpha : (BLOCK_Q * num_queries_per_kv, )
alpha = tl.exp(M - m_j)
# acc : (BLOCK_Q * num_queries_per_kv, HEAD_SIZE_PADDED,)
acc = acc * alpha[:, None]
# update constants
L = L * alpha + l_j
M = m_j
# acc : (BLOCK_Q * num_queries_per_kv, HEAD_SIZE_PADDED,)
acc += tl.dot(P.to(V.dtype), V)
# epilogue
acc = acc / L[:, None]
output_offset = (query_offset_0[:, None] * output_stride_0 +
query_offset_1[:, None] * output_stride_1 +
offs_d[None, :])
tl.store(
output_ptr + output_offset,
acc,
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
)
def unified_attention(
q,
k,
v,
out,
cu_seqlens_q,
max_seqlen_q,
seqused_k,
max_seqlen_k,
softmax_scale,
causal,
window_size,
block_table,
softcap,
q_descale,
k_descale,
v_descale,
alibi_slopes=None,
):
assert causal, "Only causal attention is supported"
assert q_descale is None, "Q scales not supported"
use_alibi_slopes = alibi_slopes is not None
block_size = v.shape[1]
num_seqs = len(seqused_k)
num_query_heads = q.shape[1]
num_kv_heads = k.shape[2]
num_queries_per_kv = num_query_heads // num_kv_heads
head_size = q.shape[2]
BLOCK_M = 16
BLOCK_Q = BLOCK_M // num_queries_per_kv
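# Each program instance covers BLOCK_Q query tokens for one KV head, i.e.
# BLOCK_Q * num_queries_per_kv (= BLOCK_M when it divides evenly) query rows.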
# Ideally we would launch the kernel with:
# \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks.
# However, it is slow to realize the query_lens on the CPU.
# Instead we use the upper bound:
# \sum_i[ceil(query_len[i] / BLOCK_Q)]
# <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1]
# = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs
# <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
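# e.g. BLOCK_Q = 4, query_lens = [1, 2, 5]: sum of ceils = 1 + 1 + 2 = 4,
# while the bound gives floor(8 / 4) + 3 = 5; the surplus block exits early
# via the cur_batch_query_len check in the kernel.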
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
kernel_unified_attention_2d[(
total_num_q_blocks,
num_kv_heads,
)](
output_ptr=out,
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
scale=softmax_scale,
k_scale=k_descale,
v_scale=v_descale,
softcap=softcap,
num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv,
block_table_stride=block_table.stride(0),
query_stride_0=q.stride(0),
query_stride_1=q.stride(1),
output_stride_0=out.stride(0),
output_stride_1=out.stride(1),
BLOCK_SIZE=block_size,
HEAD_SIZE=head_size,
HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
USE_ALIBI_SLOPES=use_alibi_slopes,
USE_SOFTCAP=(softcap > 0),
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),
stride_k_cache_2=k.stride(2),
stride_k_cache_3=k.stride(3),
stride_v_cache_0=v.stride(0),
stride_v_cache_1=v.stride(1),
stride_v_cache_2=v.stride(2),
stride_v_cache_3=v.stride(3),
query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs,
)


@@ -4,11 +4,10 @@ from typing import Any, Optional
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import (
FlashAttentionMetadata, FlashAttentionMetadataBuilder)
@@ -87,6 +86,11 @@ class TritonAttentionImpl(AttentionImpl):
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap
self.use_irope = use_irope
assert self.num_heads % self.num_kv_heads == 0
@@ -143,11 +147,9 @@ class TritonAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)
# Reshape the input keys and values and store them in the cache.
PagedAttention.write_to_paged_cache(
key_cache, value_cache = kv_cache.unbind(0)
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
@@ -158,6 +160,18 @@ class TritonAttentionImpl(AttentionImpl):
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fn)
value_cache = value_cache.view(torch.float8_e4m3fn)
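# Quantize the query to fp8 as well: flatten to (num_tokens, num_heads * head_size)
# for ops.scaled_fp8_quant, then restore the original shape below.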
num_tokens, num_heads, head_size = query.shape
assert layer._q_scale == 1.0, \
"A non 1.0 q_scale is not currently supported."
query, _ = ops.scaled_fp8_quant(
query.reshape(
(num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
use_local_attn = \
(self.use_irope and attn_metadata.local_attn_metadata is not None)
@@ -165,34 +179,37 @@
assert attn_metadata.local_attn_metadata is not None
local_metadata = attn_metadata.local_attn_metadata
cu_seqlens_q = local_metadata.local_query_start_loc
sequesd_k = local_metadata.local_seqused_k
seqused_k = local_metadata.local_seqused_k
max_seqlen_q = local_metadata.local_max_query_len
max_seqlen_k = local_metadata.local_max_seq_len
block_table = local_metadata.local_block_table
else:
cu_seqlens_q = attn_metadata.query_start_loc
sequesd_k = attn_metadata.seq_lens
seqused_k = attn_metadata.seq_lens
max_seqlen_q = attn_metadata.max_query_len
max_seqlen_k = attn_metadata.max_seq_len
block_table = attn_metadata.block_table
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode(query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=sequesd_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale)
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
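# layer._k_scale / layer._v_scale are scalar tensors; expanding them to
# (num_seqs, num_kv_heads) matches the per-sequence, per-KV-head descale
# shape used in the test above.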
unified_attention(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
out=output[:num_actual_tokens],
cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q,
seqused_k=seqused_k,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=True,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
block_table=block_table,
softcap=self.logits_soft_cap,
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
return output