[Kernel] Decouple Tile Size from Block Size in Triton Unified Attention Kernel (#21197)

Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>

parent bc19d75985
commit 01a583fea4
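
This change decouples the number of tokens processed per inner-loop iteration (the new TILE_SIZE constexpr) from the KV-cache paging granularity (BLOCK_SIZE, the number of tokens per cache block). A tile may now span several cache blocks, so the kernel gathers one physical block index per token instead of one per iteration. A minimal NumPy sketch of that mapping, with toy shapes and hypothetical variable names (not the kernel's actual pointers or strides):

import numpy as np

BLOCK_SIZE = 16    # tokens per KV-cache block (paging granularity)
TILE_SIZE = 32     # tokens handled per inner-loop iteration, now independent of BLOCK_SIZE

# toy paged key cache: 8 physical blocks, head size 4
key_cache = np.arange(8 * BLOCK_SIZE * 4, dtype=np.float32).reshape(8, BLOCK_SIZE, 4)
block_table = np.array([3, 0, 5, 1, 7, 2, 6, 4])  # logical block -> physical block

def load_key_tile(j):
    # token positions covered by tile j; one tile can straddle cache blocks
    seq_offset = j * TILE_SIZE + np.arange(TILE_SIZE)
    physical_block_idx = block_table[seq_offset // BLOCK_SIZE]     # per-token block index
    return key_cache[physical_block_idx, seq_offset % BLOCK_SIZE]  # (TILE_SIZE, 4)

print(load_key_tile(0).shape)  # (32, 4): tile 0 gathers keys from two 16-token blocks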
@@ -102,9 +102,6 @@ def test_triton_unified_attn(
 ) -> None:
     torch.set_default_device("cuda")
 
-    if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32:
-        pytest.skip("block size must be at least 32 for fp8")
-
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
@@ -73,6 +73,7 @@ def kernel_unified_attention_2d(
         output_stride_1: tl.int64,  # int, should be equal to head_size
         qq_bias_stride_0: tl.int64,  # int
         BLOCK_SIZE: tl.constexpr,  # int
+        TILE_SIZE: tl.constexpr,  # int must be power of 2
         HEAD_SIZE: tl.constexpr,  # int
         HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
         USE_ALIBI_SLOPES: tl.constexpr,  # bool
@@ -118,6 +119,7 @@ def kernel_unified_attention_2d(
 
     offs_m = tl.arange(0, BLOCK_M)
     offs_d = tl.arange(0, HEAD_SIZE_PADDED)
+    offs_t = tl.arange(0, TILE_SIZE)
     query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
 
     query_offset_0 = cur_batch_in_all_start_index + query_pos
@@ -177,31 +179,32 @@ def kernel_unified_attention_2d(
     # actual sequence length
     max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
 
-    # calculate the number of tiles (blocks) that need to be processed to
-    # cover the longest sequence prefix (due to causal masking, blocks beyond
+    # calculate the number of tiles that need to be processed to
+    # cover the longest sequence prefix (due to causal masking, tiles beyond
     # this prefix can be skipped)
-    num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE)
+    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
 
     # iterate through tiles
-    for j in range(0, num_blocks):
-
-        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
+    for j in range(0, num_tiles):
+        seq_offset = j * TILE_SIZE + offs_t
+        tile_mask = seq_offset < max_seq_prefix_len
 
-        offs_n = tl.arange(0, BLOCK_SIZE)
+        physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
+                                     seq_offset // BLOCK_SIZE).to(tl.int64)
 
-        v_offset = (physical_block_idx * stride_v_cache_0 +
+        v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
                     kv_head_idx * stride_v_cache_2 +
                     offs_d[None, :] * stride_v_cache_3 +
-                    offs_n[:, None] * stride_v_cache_1)
+                    (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)
 
-        k_offset = (physical_block_idx * stride_k_cache_0 +
+        k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
                     kv_head_idx * stride_k_cache_2 +
                     offs_d[:, None] * stride_k_cache_3 +
-                    offs_n[None, :] * stride_k_cache_1)
+                    (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)
 
-        # K : (HEAD_SIZE, BLOCK_SIZE)
+        # K : (HEAD_SIZE, TILE_SIZE)
         K_load = tl.load(key_cache_ptr + k_offset,
-                         mask=dim_mask[:, None],
+                         mask=dim_mask[:, None] & tile_mask[None, :],
                          other=0.0)
 
         if K_load.dtype.is_fp8():
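
Since K is loaded transposed as (HEAD_SIZE_PADDED, TILE_SIZE) while V is loaded as (TILE_SIZE, HEAD_SIZE_PADDED), the new tile_mask is broadcast along opposite axes in the two loads, and the per-token block index is likewise broadcast as [None, :] for K and [:, None] for V. A small shape check with illustrative numbers (not taken from the kernel's tuned configuration):

import numpy as np

HEAD_SIZE, HEAD_SIZE_PADDED, TILE_SIZE = 24, 32, 16
dim_mask = np.arange(HEAD_SIZE_PADDED) < HEAD_SIZE   # switch off padded head-dim lanes
tile_mask = np.arange(TILE_SIZE) < 10                # switch off tokens past the prefix

k_mask = dim_mask[:, None] & tile_mask[None, :]      # K load mask: (HEAD_SIZE_PADDED, TILE_SIZE)
v_mask = dim_mask[None, :] & tile_mask[:, None]      # V load mask: (TILE_SIZE, HEAD_SIZE_PADDED)
print(k_mask.shape, v_mask.shape)                    # (32, 16) (16, 32)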
@@ -212,9 +215,9 @@ def kernel_unified_attention_2d(
         else:
             K = K_load
 
-        # V : (BLOCK_SIZE, HEAD_SIZE)
+        # V : (TILE_SIZE, HEAD_SIZE)
         V_load = tl.load(value_cache_ptr + v_offset,
-                         mask=dim_mask[None, :],
+                         mask=dim_mask[None, :] & tile_mask[:, None],
                          other=0.0)
 
         if V_load.dtype.is_fp8():
@@ -225,12 +228,10 @@ def kernel_unified_attention_2d(
         else:
             V = V_load
 
-        seq_offset = j * BLOCK_SIZE + offs_n
-
         seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
 
-        # S : (BLOCK_M, BLOCK_SIZE)
-        S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
+        # S : (BLOCK_M, TILE_SIZE)
+        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
 
         S += scale * tl.dot(Q, K)
 
@@ -262,11 +263,12 @@ def kernel_unified_attention_2d(
         # compute running maximum
         # m_j : (BLOCK_M,)
         m_j = tl.maximum(M, tl.max(S, axis=1))
 
         # For sliding window there's a chance the max is -inf due to masking of
         # the entire row. In this case we need to set m_j 0 to avoid NaN
         m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
 
-        # P : (BLOCK_M, BLOCK_SIZE)
+        # P : (BLOCK_M, TILE_SIZE)
         P = tl.exp(S - m_j[:, None])
 
         # l_j : (BLOCK_M,)
@@ -327,6 +329,7 @@ def kernel_unified_attention_3d(
         query_stride_1: tl.int64,  # int, should be equal to head_size
         qq_bias_stride_0: tl.int64,  # int
         BLOCK_SIZE: tl.constexpr,  # int
+        TILE_SIZE: tl.constexpr,  # int, must be power of 2
         HEAD_SIZE: tl.constexpr,  # int
         HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
         USE_ALIBI_SLOPES: tl.constexpr,  # bool
@@ -374,20 +377,19 @@ def kernel_unified_attention_3d(
 
     # number of segments for this particular sequence
     num_segments = NUM_SEGMENTS_PER_SEQ
-    blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)
+    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
 
-    if segm_idx * blocks_per_segment * BLOCK_SIZE >= seq_len:
+    if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len:
         return
 
     offs_m = tl.arange(0, BLOCK_M)
     offs_d = tl.arange(0, HEAD_SIZE_PADDED)
+    offs_t = tl.arange(0, TILE_SIZE)
     query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv
 
     query_offset_0 = cur_batch_in_all_start_index + query_pos
     query_offset_1 = kv_head_idx * num_queries_per_kv + \
         offs_m % num_queries_per_kv
 
     query_offset = (query_offset_0[:, None] * query_stride_0 +
                     query_offset_1[:, None] * query_stride_1 + offs_d[None, :])
 
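
In the split-KV (3d) kernel each sequence is divided into NUM_SEGMENTS_PER_SEQ segments, and a segment now spans tiles_per_segment tiles of TILE_SIZE tokens; a program whose segment starts beyond seq_len exits immediately. A worked example with illustrative numbers:

def cdiv(a, b):
    return (a + b - 1) // b

seq_len, NUM_SEGMENTS_PER_SEQ, TILE_SIZE = 100, 16, 16
tiles_per_segment = cdiv(seq_len, NUM_SEGMENTS_PER_SEQ * TILE_SIZE)   # cdiv(100, 256) = 1

for segm_idx in range(NUM_SEGMENTS_PER_SEQ):
    start_token = segm_idx * tiles_per_segment * TILE_SIZE            # 0, 16, 32, ...
    if start_token >= seq_len:
        print(f"segment {segm_idx}: early return")                    # segments 7..15 do no work
    else:
        end_token = min(start_token + TILE_SIZE, seq_len) - 1
        print(f"segment {segm_idx}: tokens {start_token}..{end_token}")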
@@ -433,30 +435,44 @@ def kernel_unified_attention_3d(
         qq_bias_row_ptrs = (qq_bias_ptr + query_pos[:, None] * qq_bias_stride_0
                             )  # shape: [BLOCK_M]
 
-    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
+    # compute the length of the longest sequence prefix spanned by any
+    # query token in the current q_block (q_block_local_idx)
+    max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
+        BLOCK_M - 1) // num_queries_per_kv + 1
+
+    # adjust for potential padding in the last q_block by considering the
+    # actual sequence length
+    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
+
+    # calculate the number of tiles that need to be processed to
+    # cover the longest sequence prefix (due to causal masking, tiles beyond
+    # this prefix can be skipped)
+    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
 
     # iterate through tiles within current segment
     for j in range(
-            segm_idx * blocks_per_segment,
-            min((segm_idx + 1) * blocks_per_segment, num_blocks),
+            segm_idx * tiles_per_segment,
+            min((segm_idx + 1) * tiles_per_segment, num_tiles),
     ):
-        physical_block_idx = tl.load(block_tables_ptr + block_table_offset + j)
+        seq_offset = j * TILE_SIZE + offs_t
+        tile_mask = seq_offset < max_seq_prefix_len
 
-        offs_n = tl.arange(0, BLOCK_SIZE)
+        physical_block_idx = tl.load(block_tables_ptr + block_table_offset +
+                                     seq_offset // BLOCK_SIZE).to(tl.int64)
 
-        v_offset = (physical_block_idx * stride_v_cache_0 +
+        v_offset = (physical_block_idx[:, None] * stride_v_cache_0 +
                     kv_head_idx * stride_v_cache_2 +
                     offs_d[None, :] * stride_v_cache_3 +
-                    offs_n[:, None] * stride_v_cache_1)
+                    (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1)
 
-        k_offset = (physical_block_idx * stride_k_cache_0 +
+        k_offset = (physical_block_idx[None, :] * stride_k_cache_0 +
                     kv_head_idx * stride_k_cache_2 +
                     offs_d[:, None] * stride_k_cache_3 +
-                    offs_n[None, :] * stride_k_cache_1)
+                    (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1)
 
-        # K : (HEAD_SIZE, BLOCK_SIZE)
+        # K : (HEAD_SIZE, TILE_SIZE)
         K_load = tl.load(key_cache_ptr + k_offset,
-                         mask=dim_mask[:, None],
+                         mask=dim_mask[:, None] & tile_mask[None, :],
                          other=0.0)
 
         if K_load.dtype.is_fp8():
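
The prefix bound copied into the 3d kernel above caps the tile loop: under causal masking, no query row in this q_block can attend past context_len plus its own position, so later tiles are skipped. A numeric check with illustrative values (a 512-token sequence whose last 128 tokens are the current query chunk):

def cdiv(a, b):
    return (a + b - 1) // b

seq_len, context_len = 512, 384            # 128 new query tokens
BLOCK_Q, BLOCK_M, num_queries_per_kv = 16, 16, 1
TILE_SIZE = 32
q_block_local_idx = 0                      # first query block of this sequence

max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
    BLOCK_M - 1) // num_queries_per_kv + 1           # 384 + 0 + 15 + 1 = 400
max_seq_prefix_len = min(max_seq_prefix_len, seq_len)
num_tiles = cdiv(max_seq_prefix_len, TILE_SIZE)      # 13

print(num_tiles, cdiv(seq_len, TILE_SIZE))           # 13 of 16 tiles visited; 3 skipped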
@@ -467,9 +483,9 @@ def kernel_unified_attention_3d(
         else:
             K = K_load
 
-        # V : (BLOCK_SIZE, HEAD_SIZE)
+        # V : (TILE_SIZE, HEAD_SIZE)
         V_load = tl.load(value_cache_ptr + v_offset,
-                         mask=dim_mask[None, :],
+                         mask=dim_mask[None, :] & tile_mask[:, None],
                          other=0.0)
 
         if V_load.dtype.is_fp8():
@@ -480,13 +496,10 @@ def kernel_unified_attention_3d(
         else:
             V = V_load
 
-        seq_offset = j * BLOCK_SIZE + offs_n
-
         seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
 
-        # S : (BLOCK_M, BLOCK_SIZE)
-        S = tl.zeros(shape=(BLOCK_M, BLOCK_SIZE), dtype=tl.float32)
+        # S : (BLOCK_M, TILE_SIZE)
+        S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
 
         S += scale * tl.dot(Q, K)
 
         if USE_SOFTCAP:
@@ -517,11 +530,12 @@ def kernel_unified_attention_3d(
         # compute running maximum
         # m_j : (BLOCK_M,)
         m_j = tl.maximum(M, tl.max(S, axis=1))
 
         # For sliding window there's a chance the max is -inf due to masking of
         # the entire row. In this case we need to set m_j 0 to avoid NaN
         m_j = tl.where(m_j > float("-inf"), m_j, 0.0)
 
-        # P : (BLOCK_M, BLOCK_SIZE,)
+        # P : (BLOCK_M, TILE_SIZE,)
         P = tl.exp(S - m_j[:, None])
 
         # l_j : (BLOCK_M,)
@@ -573,7 +587,7 @@ def reduce_segments(
         output_stride_0: tl.int64,  # int
         output_stride_1: tl.int64,  # int, should be equal to head_size
         block_table_stride: tl.int64,  # int
-        BLOCK_SIZE: tl.constexpr,  # int
+        TILE_SIZE: tl.constexpr,  # int
         HEAD_SIZE: tl.constexpr,  # int, must be power of 2
         HEAD_SIZE_PADDED: tl.constexpr,  # int, must be power of 2
         query_start_len_ptr,  # [num_seqs+1]
@@ -594,10 +608,10 @@ def reduce_segments(
 
     # number of segments for this particular sequence
     num_segments = NUM_SEGMENTS_PER_SEQ
-    blocks_per_segment = cdiv_fn(seq_len, num_segments * BLOCK_SIZE)
+    tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE)
 
     # create masks for subsequent loads
-    act_num_segments = cdiv_fn(seq_len, blocks_per_segment * BLOCK_SIZE)
+    act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE)
     segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full(
         [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32)
     dim_mask = tl.where(tl.arange(0, HEAD_SIZE_PADDED) < HEAD_SIZE, 1,
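
reduce_segments only needs to combine the segments that actually produced partial results. Reusing the illustrative numbers from the sketch above (seq_len = 100, TILE_SIZE = 16, 16 segments), only the first 7 segments are live and the rest are masked out:

import numpy as np

def cdiv(a, b):
    return (a + b - 1) // b

seq_len, TILE_SIZE, NUM_SEGMENTS_PER_SEQ = 100, 16, 16
tiles_per_segment = cdiv(seq_len, NUM_SEGMENTS_PER_SEQ * TILE_SIZE)   # 1
act_num_segments = cdiv(seq_len, tiles_per_segment * TILE_SIZE)       # 7 segments hold data
segm_mask = np.arange(NUM_SEGMENTS_PER_SEQ) < act_num_segments
print(segm_mask.astype(int))   # [1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0]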
@@ -671,13 +685,10 @@ def unified_attention(
     # Optional tensor for sinks
     sinks=None,
 ):
 
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
 
-    block_size = v.shape[1]
-    assert q.element_size() >= 2 or block_size >= 32, \
-        "Block size must be at least 32 for fp8"
-
     if sinks is not None:
         assert sinks.shape[0] == q.shape[1], \
             "Sinks must be num_query_heads size"
@@ -707,6 +718,12 @@ def unified_attention(
     # = floor(q.shape[0] / BLOCK_Q) + num_seqs
     total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
 
+    # Assigning default tile sizes for prefill and decode.
+    # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
+    # and at least 16 for all other data types.
+    TILE_SIZE_PREFILL = 32
+    TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
+
     # if batch contains a prefill
     if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128:
         kernel_unified_attention_2d[(
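
With the tile size decoupled, the fp8 minimum-width requirement is met by these tile-size defaults rather than by the cache block size, which is why the block-size assert and the matching test skip could be dropped earlier in this diff. A condensed sketch of the selection logic (a hypothetical helper, not part of the kernel's API; the fp8 line needs a torch build with float8 dtypes):

import torch

def default_tile_sizes(q: torch.Tensor) -> tuple[int, int]:
    # 1-byte (fp8) inputs need tiles of at least 32 tokens; wider dtypes need 16
    tile_size_prefill = 32
    tile_size_decode = 16 if q.element_size() >= 2 else 32
    return tile_size_prefill, tile_size_decode

print(default_tile_sizes(torch.empty(1, dtype=torch.float16)))        # (32, 16)
print(default_tile_sizes(torch.empty(1, dtype=torch.float8_e4m3fn)))  # (32, 32)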
@@ -736,6 +753,7 @@ def unified_attention(
             output_stride_1=out.stride(1),
             qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
             BLOCK_SIZE=block_size,
+            TILE_SIZE=TILE_SIZE_PREFILL,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             USE_ALIBI_SLOPES=use_alibi_slopes,
@@ -809,6 +827,7 @@ def unified_attention(
             query_stride_1=q.stride(1),
             qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0,
             BLOCK_SIZE=block_size,
+            TILE_SIZE=TILE_SIZE_DECODE,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             USE_ALIBI_SLOPES=use_alibi_slopes,
@@ -830,7 +849,6 @@ def unified_attention(
             BLOCK_M=BLOCK_M,
             NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
         )
 
         reduce_segments[(q.shape[0], num_query_heads)](
             output_ptr=out,
             segm_output_ptr=segm_output,
@@ -844,7 +862,7 @@ def unified_attention(
             output_stride_0=out.stride(0),
             output_stride_1=out.stride(1),
             block_table_stride=block_table.stride(0),
-            BLOCK_SIZE=block_size,
+            TILE_SIZE=TILE_SIZE_DECODE,
             HEAD_SIZE=head_size,
             HEAD_SIZE_PADDED=triton.next_power_of_2(head_size),
             query_start_len_ptr=cu_seqlens_q,