[Bugfix] Make condition in triton kernel constexpr (#22370)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
Gregory Shtrasberg 2025-08-06 13:00:58 -04:00 committed by GitHub
parent 4a6b72c2ab
commit 2435ea7ed5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 2 deletions

View File

@@ -60,6 +60,7 @@ def kernel_paged_attention_2d(
stride_v_cache_3: tl.int64, # int
filter_by_query_len: tl.constexpr, # bool
query_start_len_ptr, # [num_seqs+1]
+USE_SINKS: tl.constexpr, # bool
):
seq_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
@@ -96,7 +97,7 @@ def kernel_paged_attention_2d(
block_table_offset = seq_idx * block_table_stride
-if sink_ptr is None:
+if not USE_SINKS:
M = tl.full([num_queries_per_kv_padded],
float("-inf"),
dtype=tl.float32)
@@ -386,4 +387,5 @@ def chunked_prefill_paged_decode(
stride_v_cache_3=value_cache.stride(3),
filter_by_query_len=True,
query_start_len_ptr=query_start_loc,
+USE_SINKS=sinks is not None,
)

View File

@@ -81,6 +81,7 @@ def _fwd_kernel(Q,
num_unroll_cache: tl.constexpr,
num_unroll_request: tl.constexpr,
SKIP_DECODE: tl.constexpr,
+USE_SINKS: tl.constexpr,
MAX_Q_LEN: tl.constexpr = 0,
MAX_CTX_LEN: tl.constexpr = 0):
@@ -127,7 +128,7 @@ def _fwd_kernel(Q,
other=0.0) # [M,D]
# initialize pointer to m and l
-if sink_ptr is None:
+if not USE_SINKS:
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
else:
m_i = tl.load(
@@ -910,5 +911,6 @@ def context_attention_fwd(q,
num_unroll_request=1,
num_warps=4,
num_stages=1,
+USE_SINKS=sinks is not None,
**extra_kargs)
return