mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 15:45:34 +08:00
[Bugfix] fixes the causal_conv1d_update kernel update non-speculative decoding cases (#24680)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent
40b6c9122b
commit
880c741bb6
@ -720,15 +720,15 @@ def _causal_conv1d_update_kernel(
|
|||||||
# STEP 2: assume state_len > seqlen
|
# STEP 2: assume state_len > seqlen
|
||||||
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||||
|
|
||||||
# The conv_state updates works in a sliding window manner,
|
# With speculative decoding, the conv_state updates works in a sliding
|
||||||
# at each forward pass, the tokens are shift by 1, so we
|
# window manner, at each forward pass, the tokens are shift by 1, so we
|
||||||
# load since idx_tokens + 1.
|
# load since idx_tokens + 1.
|
||||||
conv_state_ptrs_source = (
|
conv_state_ptrs_source = (
|
||||||
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
|
conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) +
|
||||||
conv_state_token_offset * stride_conv_state_tok +
|
conv_state_token_offset * stride_conv_state_tok +
|
||||||
(idx_feats * stride_conv_state_dim)[None, :] +
|
(idx_feats * stride_conv_state_dim)[None, :] +
|
||||||
((idx_tokens + 1) * stride_conv_state_tok)[:, None]
|
((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) *
|
||||||
) # [BLOCK_M, BLOCK_N]
|
stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N]
|
||||||
mask = ((conv_state_batch_coord < num_cache_lines)
|
mask = ((conv_state_batch_coord < num_cache_lines)
|
||||||
& ((idx_tokens + seqlen) < state_len)[:, None]
|
& ((idx_tokens + seqlen) < state_len)[:, None]
|
||||||
& (idx_feats < dim)[None, :])
|
& (idx_feats < dim)[None, :])
|
||||||
@ -924,7 +924,10 @@ def causal_conv1d_update(
|
|||||||
)
|
)
|
||||||
stride_state_indices = conv_state_indices.stride(
|
stride_state_indices = conv_state_indices.stride(
|
||||||
0) if conv_state_indices is not None else 0
|
0) if conv_state_indices is not None else 0
|
||||||
|
if num_accepted_tokens is not None:
|
||||||
state_len = width - 1 + (seqlen - 1) # effective state_len needed
|
state_len = width - 1 + (seqlen - 1) # effective state_len needed
|
||||||
|
else:
|
||||||
|
state_len = width - 1
|
||||||
np2_statelen = triton.next_power_of_2(state_len)
|
np2_statelen = triton.next_power_of_2(state_len)
|
||||||
|
|
||||||
def grid(META):
|
def grid(META):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user