mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-04 01:39:37 +08:00
[Bugfix] Make condition in triton kernel constexpr (#22370)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
parent
4a6b72c2ab
commit
2435ea7ed5
@ -60,6 +60,7 @@ def kernel_paged_attention_2d(
|
||||
stride_v_cache_3: tl.int64, # int
|
||||
filter_by_query_len: tl.constexpr, # bool
|
||||
query_start_len_ptr, # [num_seqs+1]
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
):
|
||||
seq_idx = tl.program_id(0)
|
||||
kv_head_idx = tl.program_id(1)
|
||||
@ -96,7 +97,7 @@ def kernel_paged_attention_2d(
|
||||
|
||||
block_table_offset = seq_idx * block_table_stride
|
||||
|
||||
if sink_ptr is None:
|
||||
if not USE_SINKS:
|
||||
M = tl.full([num_queries_per_kv_padded],
|
||||
float("-inf"),
|
||||
dtype=tl.float32)
|
||||
@ -386,4 +387,5 @@ def chunked_prefill_paged_decode(
|
||||
stride_v_cache_3=value_cache.stride(3),
|
||||
filter_by_query_len=True,
|
||||
query_start_len_ptr=query_start_loc,
|
||||
USE_SINKS=sinks is not None,
|
||||
)
|
||||
|
||||
@ -81,6 +81,7 @@ def _fwd_kernel(Q,
|
||||
num_unroll_cache: tl.constexpr,
|
||||
num_unroll_request: tl.constexpr,
|
||||
SKIP_DECODE: tl.constexpr,
|
||||
USE_SINKS: tl.constexpr,
|
||||
MAX_Q_LEN: tl.constexpr = 0,
|
||||
MAX_CTX_LEN: tl.constexpr = 0):
|
||||
|
||||
@ -127,7 +128,7 @@ def _fwd_kernel(Q,
|
||||
other=0.0) # [M,D]
|
||||
|
||||
# initialize pointer to m and l
|
||||
if sink_ptr is None:
|
||||
if not USE_SINKS:
|
||||
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
else:
|
||||
m_i = tl.load(
|
||||
@ -910,5 +911,6 @@ def context_attention_fwd(q,
|
||||
num_unroll_request=1,
|
||||
num_warps=4,
|
||||
num_stages=1,
|
||||
USE_SINKS=sinks is not None,
|
||||
**extra_kargs)
|
||||
return
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user