mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-09 15:08:01 +08:00
[Bugfix] Make condition in triton kernel constexpr (#22370)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
parent
4a6b72c2ab
commit
2435ea7ed5
@ -60,6 +60,7 @@ def kernel_paged_attention_2d(
|
|||||||
stride_v_cache_3: tl.int64, # int
|
stride_v_cache_3: tl.int64, # int
|
||||||
filter_by_query_len: tl.constexpr, # bool
|
filter_by_query_len: tl.constexpr, # bool
|
||||||
query_start_len_ptr, # [num_seqs+1]
|
query_start_len_ptr, # [num_seqs+1]
|
||||||
|
USE_SINKS: tl.constexpr, # bool
|
||||||
):
|
):
|
||||||
seq_idx = tl.program_id(0)
|
seq_idx = tl.program_id(0)
|
||||||
kv_head_idx = tl.program_id(1)
|
kv_head_idx = tl.program_id(1)
|
||||||
@ -96,7 +97,7 @@ def kernel_paged_attention_2d(
|
|||||||
|
|
||||||
block_table_offset = seq_idx * block_table_stride
|
block_table_offset = seq_idx * block_table_stride
|
||||||
|
|
||||||
if sink_ptr is None:
|
if not USE_SINKS:
|
||||||
M = tl.full([num_queries_per_kv_padded],
|
M = tl.full([num_queries_per_kv_padded],
|
||||||
float("-inf"),
|
float("-inf"),
|
||||||
dtype=tl.float32)
|
dtype=tl.float32)
|
||||||
@ -386,4 +387,5 @@ def chunked_prefill_paged_decode(
|
|||||||
stride_v_cache_3=value_cache.stride(3),
|
stride_v_cache_3=value_cache.stride(3),
|
||||||
filter_by_query_len=True,
|
filter_by_query_len=True,
|
||||||
query_start_len_ptr=query_start_loc,
|
query_start_len_ptr=query_start_loc,
|
||||||
|
USE_SINKS=sinks is not None,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -81,6 +81,7 @@ def _fwd_kernel(Q,
|
|||||||
num_unroll_cache: tl.constexpr,
|
num_unroll_cache: tl.constexpr,
|
||||||
num_unroll_request: tl.constexpr,
|
num_unroll_request: tl.constexpr,
|
||||||
SKIP_DECODE: tl.constexpr,
|
SKIP_DECODE: tl.constexpr,
|
||||||
|
USE_SINKS: tl.constexpr,
|
||||||
MAX_Q_LEN: tl.constexpr = 0,
|
MAX_Q_LEN: tl.constexpr = 0,
|
||||||
MAX_CTX_LEN: tl.constexpr = 0):
|
MAX_CTX_LEN: tl.constexpr = 0):
|
||||||
|
|
||||||
@ -127,7 +128,7 @@ def _fwd_kernel(Q,
|
|||||||
other=0.0) # [M,D]
|
other=0.0) # [M,D]
|
||||||
|
|
||||||
# initialize pointer to m and l
|
# initialize pointer to m and l
|
||||||
if sink_ptr is None:
|
if not USE_SINKS:
|
||||||
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||||
else:
|
else:
|
||||||
m_i = tl.load(
|
m_i = tl.load(
|
||||||
@ -910,5 +911,6 @@ def context_attention_fwd(q,
|
|||||||
num_unroll_request=1,
|
num_unroll_request=1,
|
||||||
num_warps=4,
|
num_warps=4,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
|
USE_SINKS=sinks is not None,
|
||||||
**extra_kargs)
|
**extra_kargs)
|
||||||
return
|
return
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user