mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-27 16:59:22 +08:00
[Bugfix][ROCm] fix the power of 2 exception from triton_unified_attention.py when running llama4 models and unit test fix (#18100)
Signed-off-by: Hongxia Yang <hongxia.yang@amd.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent
7951d78738
commit
269d901734
@ -13,7 +13,9 @@ HEAD_SIZES = [128, 256]
|
||||
BLOCK_SIZES = [16, 32]
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
QDTYPES = [None, torch.float8_e4m3fn]
|
||||
QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
|
||||
None, torch.float8_e4m3fnuz
|
||||
]
|
||||
# one value large enough to test overflow in index calculation.
|
||||
# one value small enough to test the schema op check
|
||||
NUM_BLOCKS = [32768, 2048]
|
||||
|
||||
@ -29,41 +29,42 @@ def apply_softcap(S, x):
|
||||
|
||||
@triton.jit
|
||||
def kernel_unified_attention_2d(
|
||||
output_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
scale, # float32
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
softcap, # float32
|
||||
num_query_heads: tl.constexpr, # int
|
||||
num_queries_per_kv: tl.constexpr, # int
|
||||
block_table_stride: tl.int64, # int
|
||||
query_stride_0: tl.int64, # int
|
||||
query_stride_1: tl.int64, # int, should be equal to head_size
|
||||
output_stride_0: tl.int64, # int
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
USE_SOFTCAP: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
stride_k_cache_1: tl.int64, # int
|
||||
stride_k_cache_2: tl.int64, # int
|
||||
stride_k_cache_3: tl.constexpr, # int
|
||||
stride_v_cache_0: tl.int64, # int
|
||||
stride_v_cache_1: tl.int64, # int
|
||||
stride_v_cache_2: tl.int64, # int
|
||||
stride_v_cache_3: tl.constexpr, # int
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
BLOCK_Q: tl.constexpr, # int
|
||||
num_seqs: tl.int32,
|
||||
output_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
query_ptr, # [num_tokens, num_query_heads, head_size]
|
||||
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
|
||||
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
|
||||
seq_lens_ptr, # [num_seqs]
|
||||
alibi_slopes_ptr, # [num_query_heads]
|
||||
scale, # float32
|
||||
k_scale, # float32
|
||||
v_scale, # float32
|
||||
softcap, # float32
|
||||
num_query_heads: tl.constexpr, # int
|
||||
num_queries_per_kv: tl.constexpr, # int
|
||||
block_table_stride: tl.int64, # int
|
||||
query_stride_0: tl.int64, # int
|
||||
query_stride_1: tl.int64, # int, should be equal to head_size
|
||||
output_stride_0: tl.int64, # int
|
||||
output_stride_1: tl.int64, # int, should be equal to head_size
|
||||
BLOCK_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE: tl.constexpr, # int
|
||||
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
|
||||
USE_ALIBI_SLOPES: tl.constexpr, # bool
|
||||
USE_SOFTCAP: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
stride_k_cache_1: tl.int64, # int
|
||||
stride_k_cache_2: tl.int64, # int
|
||||
stride_k_cache_3: tl.constexpr, # int
|
||||
stride_v_cache_0: tl.int64, # int
|
||||
stride_v_cache_1: tl.int64, # int
|
||||
stride_v_cache_2: tl.int64, # int
|
||||
stride_v_cache_3: tl.constexpr, # int
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
BLOCK_Q: tl.constexpr, # int
|
||||
num_seqs: tl.int32,
|
||||
BLOCK_M: tl.constexpr, # int
|
||||
):
|
||||
|
||||
q_block_global_idx = tl.program_id(0)
|
||||
@ -94,15 +95,13 @@ def kernel_unified_attention_2d(
|
||||
if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
|
||||
return
|
||||
|
||||
offs_m = tl.arange(0, BLOCK_Q * num_queries_per_kv)
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
|
||||
|
||||
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
|
||||
|
||||
query_offset_0 = cur_batch_in_all_start_index + query_pos
|
||||
query_offset_1 = kv_head_idx * num_queries_per_kv + \
|
||||
offs_m % num_queries_per_kv
|
||||
|
||||
query_offset = (query_offset_0[:, None] * query_stride_0 +
|
||||
query_offset_1[:, None] * query_stride_1 + offs_d[None, :])
|
||||
|
||||
@ -110,7 +109,7 @@ def kernel_unified_attention_2d(
|
||||
query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1)
|
||||
query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1)
|
||||
|
||||
# Q : (BLOCK_Q * num_queries_per_kv, HEAD_SIZE,)
|
||||
# Q : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
Q = tl.load(
|
||||
query_ptr + query_offset,
|
||||
mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None],
|
||||
@ -119,12 +118,9 @@ def kernel_unified_attention_2d(
|
||||
|
||||
block_table_offset = seq_idx * block_table_stride
|
||||
|
||||
M = tl.full([BLOCK_Q * num_queries_per_kv],
|
||||
float("-inf"),
|
||||
dtype=tl.float32)
|
||||
L = tl.full([BLOCK_Q * num_queries_per_kv], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_Q * num_queries_per_kv, HEAD_SIZE_PADDED],
|
||||
dtype=tl.float32)
|
||||
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
|
||||
|
||||
# sequence len for this particular sequence
|
||||
seq_len = tl.load(seq_lens_ptr + seq_idx)
|
||||
@ -183,13 +179,12 @@ def kernel_unified_attention_2d(
|
||||
else:
|
||||
V = V_load
|
||||
|
||||
seq_offset = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
seq_offset = j * BLOCK_SIZE + offs_n
|
||||
|
||||
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
|
||||
|
||||
# S : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
|
||||
S = tl.zeros(shape=(BLOCK_Q * num_queries_per_kv, BLOCK_SIZE),
|
||||
dtype=tl.float32)
|
||||
# S : (BLOCK_M, BLOCK_SIZE)
|
||||
S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
|
||||
|
||||
S += scale * tl.dot(Q, K)
|
||||
|
||||
@ -207,29 +202,29 @@ def kernel_unified_attention_2d(
|
||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||
|
||||
# compute running maximum
|
||||
# m_j : (BLOCK_Q * num_queries_per_kv,)
|
||||
# m_j : (BLOCK_M,)
|
||||
m_j = tl.maximum(M, tl.max(S, axis=1))
|
||||
# For sliding window there's a chance the max is -inf due to masking of
|
||||
# the entire row. In this case we need to set m_j 0 to avoid NaN
|
||||
m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
|
||||
|
||||
# P : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
|
||||
# P : (BLOCK_M, BLOCK_SIZE)
|
||||
P = tl.exp(S - m_j[:, None])
|
||||
|
||||
# l_j : (BLOCK_Q * num_queries_per_kv,)
|
||||
# l_j : (BLOCK_M,)
|
||||
l_j = tl.sum(P, axis=1)
|
||||
|
||||
# alpha : (BLOCK_Q * num_queries_per_kv, )
|
||||
# alpha : (BLOCK_M, )
|
||||
alpha = tl.exp(M - m_j)
|
||||
|
||||
# acc : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
|
||||
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
acc = acc * alpha[:, None]
|
||||
|
||||
# update constants
|
||||
L = L * alpha + l_j
|
||||
M = m_j
|
||||
|
||||
# acc : (BLOCK_Q * num_queries_per_kv, BLOCK_SIZE,)
|
||||
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
|
||||
acc += tl.dot(P.to(V.dtype), V)
|
||||
|
||||
# epilogue
|
||||
@ -334,4 +329,5 @@ def unified_attention(
|
||||
query_start_len_ptr=cu_seqlens_q,
|
||||
BLOCK_Q=BLOCK_Q,
|
||||
num_seqs=num_seqs,
|
||||
BLOCK_M=BLOCK_M,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user