[Bugfix][Qwen3-Next] fixes the varlen issue in qwen3-next's MTP implementation. (#24957)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Author: Tao He (committed via GitHub, 2025-09-17 21:59:09 +08:00)
Commit: dd6a910aac (parent: 1b962e2457)
3 changed files with 139 additions and 34 deletions


@@ -626,6 +626,7 @@ def _causal_conv1d_update_kernel(
     cache_seqlens_ptr,  # circular buffer
     conv_state_indices_ptr,
     num_accepted_tokens_ptr,
+    query_start_loc_ptr,  # (batch + 1)
     o_ptr,  # (batch, dim, seqlen)
     # Matrix dimensions
     batch: int,
@@ -652,6 +653,7 @@ def _causal_conv1d_update_kernel(
     HAS_BIAS: tl.constexpr,
     KERNEL_WIDTH: tl.constexpr,
     SILU_ACTIVATION: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
     IS_CONTINUOUS_BATCHING: tl.constexpr,
     IS_SPEC_DECODING: tl.constexpr,
     NP2_STATELEN: tl.constexpr,
@@ -678,6 +680,25 @@ def _causal_conv1d_update_kernel(
             # not processing as this is not the actual sequence
             return

+    if IS_VARLEN:
+        query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64)
+        query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(
+            tl.int64)
+        # revise state_len and seqlen
+        state_len = state_len - (seqlen -
+                                 (query_end_index - query_start_index))
+        seqlen = query_end_index - query_start_index
+        x_offset = query_start_index * stride_x_token
+        o_offset = query_start_index * stride_o_token
+    else:
+        query_start_index = idx_seq * seqlen
+        query_end_index = query_start_index + seqlen
+        x_offset = idx_seq * stride_x_seq
+        o_offset = idx_seq * stride_o_seq
+
+    if query_start_index == query_end_index:
+        return
+
     if IS_SPEC_DECODING:
         # The rolling of conv state:
         #
@@ -692,8 +713,8 @@ def _causal_conv1d_update_kernel(
         #   - accept 1 tokens: [history2, ..., historyM, draft1]
         #   - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
         #   - and so on.
-        conv_state_token_offset = (tl.load(num_accepted_tokens_ptr + idx_seq) -
-                                   1)
+        conv_state_token_offset = (
+            tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1)
     else:
         conv_state_token_offset = 0
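
The rolling update described in the comment above can be sketched in plain Python (hypothetical helper, not part of the diff): accepting k draft tokens drops the k oldest history entries, so reads of the prior state begin num_accepted - 1 entries in, which is exactly the conv_state_token_offset computed here.

    def rolled_state(history, drafts, num_accepted):
        # keep the window size fixed: drop the oldest entries, append drafts
        return history[num_accepted:] + drafts[:num_accepted]

    state = ["h1", "h2", "h3"]
    assert rolled_state(state, ["d1", "d2"], 1) == ["h2", "h3", "d1"]
    assert rolled_state(state, ["d1", "d2"], 2) == ["h3", "d1", "d2"]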
@@ -713,9 +734,12 @@ def _causal_conv1d_update_kernel(
     if KERNEL_WIDTH >= 4:
         conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok  # [BLOCK_N]
         col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
-    if KERNEL_WIDTH == 5:
+    if KERNEL_WIDTH >= 5:
         conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok  # [BLOCK_N]
         col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
+    if KERNEL_WIDTH >= 6:
+        conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok  # [BLOCK_N]
+        col4 = tl.load(conv_states_ptrs, mask_w, 0.0)

     # STEP 2: assume state_len > seqlen
     idx_tokens = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
@@ -735,8 +759,7 @@ def _causal_conv1d_update_kernel(
     conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)

     VAL = state_len - seqlen
-    x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim
-                                                 )  # [BLOCK_N]
+    x_base = x_ptr + x_offset + (idx_feats * stride_x_dim)  # [BLOCK_N]

     x_ptrs = x_base[None, :] + (
         (idx_tokens - VAL) * stride_x_token)[:, None]  # [BLOCK_M, BLOCK_N]
@@ -782,12 +805,18 @@ def _causal_conv1d_update_kernel(
     if KERNEL_WIDTH >= 4:
         w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor
         w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
+    if KERNEL_WIDTH >= 5:
+        w_ptrs = w_base + (4 * stride_w_width)  # [BLOCK_N] tensor
+        w_col4 = tl.load(w_ptrs, mask_w, other=0.0)
+    if KERNEL_WIDTH >= 6:
+        w_ptrs = w_base + (5 * stride_w_width)  # [BLOCK_N] tensor
+        w_col5 = tl.load(w_ptrs, mask_w, other=0.0)

     x_base_1d = x_base  # starting of chunk [BLOCK_N]
     mask_x_1d = idx_feats < dim

     # STEP 5: compute each token
-    for idx_token in tl.static_range(seqlen):
+    for idx_token in tl.range(seqlen):
         acc = acc_preload

         matrix_w = w_col0
@@ -817,6 +846,37 @@ def _causal_conv1d_update_kernel(
                     matrix_w = w_col3
                     x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                     matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+            elif KERNEL_WIDTH == 5:
+                if j == 1:
+                    matrix_w = w_col1
+                    matrix_x = col1
+                elif j == 2:
+                    matrix_w = w_col2
+                    matrix_x = col2
+                elif j == 3:
+                    matrix_w = w_col3
+                    matrix_x = col3
+                elif j == 4:
+                    matrix_w = w_col4
+                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+            elif KERNEL_WIDTH == 6:
+                if j == 1:
+                    matrix_w = w_col1
+                    matrix_x = col1
+                elif j == 2:
+                    matrix_w = w_col2
+                    matrix_x = col2
+                elif j == 3:
+                    matrix_w = w_col3
+                    matrix_x = col3
+                elif j == 4:
+                    matrix_w = w_col4
+                    matrix_x = col4
+                elif j == 5:
+                    matrix_w = w_col5
+                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)

             acc += matrix_x * matrix_w  # [BLOCK_N]
@@ -829,14 +889,24 @@ def _causal_conv1d_update_kernel(
             col0 = col1
             col1 = col2
             col2 = matrix_x
+        elif KERNEL_WIDTH == 5:
+            col0 = col1
+            col1 = col2
+            col2 = col3
+            col3 = matrix_x
+        elif KERNEL_WIDTH == 6:
+            col0 = col1
+            col1 = col2
+            col2 = col3
+            col3 = col4
+            col4 = matrix_x

         if SILU_ACTIVATION:
            acc = acc / (1 + tl.exp(-acc))
         mask_1d = (idx_token < seqlen) & (idx_feats < dim
                                           )  # token-index  # feature-index
-        o_ptrs = o_ptr + (
-            idx_seq) * stride_o_seq + idx_token * stride_o_token + (
-                idx_feats * stride_o_dim)
+        o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats *
+                                                                  stride_o_dim)

         tl.store(o_ptrs, acc, mask=mask_1d)
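
The varlen indexing introduced above can be pictured in plain Python (illustrative values only): query_start_loc packs per-sequence boundaries, and each program derives its own slice instead of assuming a uniform seqlen.

    query_start_loc = [0, 3, 4, 8]        # three sequences of lengths 3, 1, 4
    for idx_seq in range(len(query_start_loc) - 1):
        start = query_start_loc[idx_seq]
        end = query_start_loc[idx_seq + 1]
        seqlen = end - start              # per-sequence, no longer uniform
        if seqlen == 0:
            continue                      # padded slot: the kernel returns early
        x_offset = start                  # scaled by stride_x_token in the kernel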
@@ -850,14 +920,18 @@ def causal_conv1d_update(
     cache_seqlens: Optional[torch.Tensor] = None,
     conv_state_indices: Optional[torch.Tensor] = None,
     num_accepted_tokens: Optional[torch.Tensor] = None,
+    query_start_loc: Optional[torch.Tensor] = None,
+    max_query_len: int = -1,
     pad_slot_id: int = PAD_SLOT_ID,
     metadata=None,
     validate_data=False,
 ):
     """
-    x: (batch, dim) or (batch, dim, seqlen)
+    x: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim)
         [shape=2: single token prediction]
         [shape=3: single or multiple tokens prediction]
+        [shape=2 with num_tokens: continuous batching, where num_tokens is
+         the total tokens of all sequences in that batch]
     conv_state: (..., dim, state_len), where state_len >= width - 1
     weight: (dim, width)
     bias: (dim,)
@@ -870,13 +944,24 @@ def causal_conv1d_update(
         If not None, the conv_state is a larger tensor along the batch dim,
         and we are selecting the batch coords specified by conv_state_indices.
         Useful for a continuous batching scenario.
+    num_accepted_tokens: (batch,), dtype int32
+        If not None, it indicates the number of accepted tokens for each
+        sequence in the batch.
+        This is used in speculative decoding, where the conv_state is updated
+        in a sliding window manner.
+    query_start_loc: (batch + 1,) int32
+        If not None, the inputs are given in a varlen fashion and this
+        indicates the starting index of each sequence in the batch.
+    max_query_len: int
+        If query_start_loc is not None, this indicates the maximum query
+        length in the batch.
     pad_slot_id: int
         if cache_indices is passed, lets the kernel identify padded
         entries that will not be processed,
         for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
         in this case, the kernel will not process entries at
         indices 0 and 3
-    out: (batch, dim) or (batch, dim, seqlen)
+    out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
     """
     if validate_data:
         assert cache_seqlens is None  # not implemented yet - ok for vLLM
@@ -886,11 +971,17 @@ def causal_conv1d_update(
         activation = "silu" if activation is True else None
     elif activation is not None:
         assert activation in ["silu", "swish"]
-    unsqueeze = x.dim() == 2
+    unsqueeze = query_start_loc is None and x.dim() == 2
     if unsqueeze:
         # make it (batch, dim, seqlen) with seqlen == 1
         x = x.unsqueeze(-1)
-    batch, dim, seqlen = x.shape
+    if query_start_loc is None:
+        batch, dim, seqlen = x.shape
+    else:
+        assert conv_state_indices is not None
+        batch = conv_state_indices.size(0)
+        dim = x.size(1)
+        seqlen = max_query_len
     _, width = weight.shape
     # conv_state: (..., dim, state_len), where state_len >= width - 1
     num_cache_lines, _, state_len = conv_state.size()
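
A minimal varlen call of this function might look as follows (a sketch; tensor sizes and values are made up for illustration):

    import torch

    dim, width, num_tokens = 64, 4, 8
    x = torch.randn(num_tokens, dim, device="cuda")      # flattened tokens
    conv_state = torch.zeros(16, dim, width - 1, device="cuda")
    weight = torch.randn(dim, width, device="cuda")

    out = causal_conv1d_update(
        x, conv_state, weight, bias=None, activation="silu",
        conv_state_indices=torch.tensor([0, 1, 2], dtype=torch.int32,
                                        device="cuda"),
        query_start_loc=torch.tensor([0, 3, 4, 8], dtype=torch.int32,
                                     device="cuda"),
        max_query_len=4,  # longest sequence in the batch (tokens 4..8)
    )
    assert out.shape == x.shape  # (num_tokens, dim), same shape as x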
@@ -916,10 +1007,17 @@ def causal_conv1d_update(
         out = x

     stride_w_dim, stride_w_width = weight.stride()

-    stride_x_seq, stride_x_dim, stride_x_token = x.stride(
-    )  # X (batch, dim, seqlen)
-    stride_o_seq, stride_o_dim, stride_o_token = out.stride()
+    if query_start_loc is None:
+        # X (batch, dim, seqlen)
+        stride_x_seq, stride_x_dim, stride_x_token = x.stride()
+        stride_o_seq, stride_o_dim, stride_o_token = out.stride()
+    else:
+        # X (dim, cu_seqlen)
+        stride_x_token, stride_x_dim = x.stride()
+        stride_x_seq = 0
+        stride_o_token, stride_o_dim = out.stride()
+        stride_o_seq = 0

     stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
     )
     stride_state_indices = conv_state_indices.stride(
@@ -945,6 +1043,7 @@ def causal_conv1d_update(
         cache_seqlens,
         conv_state_indices,
         num_accepted_tokens,
+        query_start_loc,
         out,
         # Matrix dimensions
         batch,
@@ -971,6 +1070,7 @@ def causal_conv1d_update(
         HAS_BIAS=bias is not None,
         KERNEL_WIDTH=width,
         SILU_ACTIVATION=activation in ["silu", "swish"],
+        IS_VARLEN=query_start_loc is not None,
         IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
         IS_SPEC_DECODING=num_accepted_tokens is not None,
         NP2_STATELEN=np2_statelen,
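
One detail worth noting in the varlen launch above: stride_x_seq and stride_o_seq are simply set to 0, so the kernel's shared pointer arithmetic (idx_seq * stride_seq + offset) degenerates to the token offsets derived from query_start_loc. A small sketch with hypothetical tensors:

    import torch

    xb = torch.zeros(2, 64, 3)            # (batch, dim, seqlen): three strides
    sx_seq, sx_dim, sx_tok = xb.stride()

    xv = torch.zeros(8, 64)               # (num_tokens, dim): two strides
    sx_tok_v, sx_dim_v = xv.stride()
    sx_seq_v = 0                          # idx_seq * 0 vanishes; x_offset
                                          # (from query_start_loc) selects the
                                          # sequence instead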


@@ -417,9 +417,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         self_kv_cache = self.kv_cache[forward_context.virtual_engine]
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
-        num_actual_tokens = (attn_metadata.num_prefill_tokens +
-                             attn_metadata.num_decode_tokens +
-                             attn_metadata.num_spec_decode_tokens)
+        num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens

         # 1. Set up dimensions for reshapes later
@@ -458,9 +456,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         # 2.1: process the multi-query part
         if spec_sequence_masks is not None:
-            mixed_qkv_spec = mixed_qkv_spec.view(
-                attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
-            mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
             mixed_qkv_spec = causal_conv1d_update(
                 mixed_qkv_spec,
                 conv_state,
@@ -470,9 +465,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
                 conv_state_indices=spec_state_indices_tensor[:, 0]
                 [:attn_metadata.num_spec_decodes],
                 num_accepted_tokens=num_accepted_tokens,
+                query_start_loc=spec_query_start_loc,
+                max_query_len=spec_state_indices_tensor.size(-1),
                 validate_data=False,
             )
-            mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')

         # 2.2: process the remaining part
         if attn_metadata.num_prefills > 0:
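
The effect at this call site is shape-level: the old code forced every spec sequence into a dense (num_spec_decodes, num_spec + 1, d) block, which breaks when sequences have different numbers of accepted draft tokens; now the tokens stay packed and the boundaries come from spec_query_start_loc. Illustrative sizes:

    num_spec_decodes, d = 3, 128
    spec_query_start_loc = [0, 2, 3, 5]      # 2, 1 and 2 tokens per sequence
    num_tokens = spec_query_start_loc[-1]    # 5, not num_spec_decodes * (num_spec + 1)
    mixed_qkv_spec_shape = (num_tokens, d)   # stays flat; no 'b l d' view needed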


@@ -31,6 +31,7 @@ class GDNAttentionMetadata:
     num_decode_tokens: int
     num_spec_decodes: int
     num_spec_decode_tokens: int
+    num_actual_tokens: int

     has_initial_state: Optional[torch.Tensor] = None
@@ -74,8 +75,8 @@ class GDNAttentionMetadataBuilder(
         self.use_full_cuda_graph = \
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         self.decode_cudagraph_max_bs = min(
-            self.vllm_config.scheduler_config.max_num_seqs,
-            self.compilation_config.max_capture_size)
+            self.vllm_config.scheduler_config.max_num_seqs *
+            (self.num_spec + 1), self.compilation_config.max_capture_size)

         self.spec_state_indices_tensor = torch.empty(
             (self.decode_cudagraph_max_bs, self.num_spec + 1),
@@ -194,9 +195,8 @@ class GDNAttentionMetadataBuilder(
                 dim=0,
                 out=non_spec_query_start_loc[1:])

-            num_spec_decode_tokens = min(
-                num_spec_decodes * (self.num_spec + 1),
-                spec_token_masks.size(0))
+            num_spec_decode_tokens = (query_lens.sum().item() -
+                                      num_prefill_tokens - num_decode_tokens)
             assert num_accepted_tokens is not None
             num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
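
The old formula assumed every spec sequence contributes exactly num_spec + 1 tokens; after a rollback some contribute fewer, so the count is now derived from what actually remains (illustrative numbers, with num_spec = 2):

    query_lens = [5, 1, 3, 2]   # one prefill, one decode, two spec decodes
    num_prefill_tokens = 5
    num_decode_tokens = 1
    num_spec_decode_tokens = sum(query_lens) - num_prefill_tokens - num_decode_tokens
    assert num_spec_decode_tokens == 5  # 3 + 2, not 2 * (num_spec + 1) = 6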
@@ -206,14 +206,22 @@ class GDNAttentionMetadataBuilder(
                 has_initial_state = has_initial_state[~spec_sequence_masks]
         else:
             has_initial_state = None
+        num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
+            num_spec_decode_tokens

         # prepare tensors for cudagraph
+        #
+        # With speculative decoding, the xgrammar backend may roll back
+        # tokens, causing some sequences to have fewer draft tokens than
+        # self.num_spec.
+        #
+        # In such cases, the max possible batch size for n tokens is
+        # min(n, cudagraph_max_bs).
         if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
                 and num_spec_decodes <= self.decode_cudagraph_max_bs
-                and m.num_actual_tokens <= self.decode_cudagraph_max_bs):
-            num_total_tokens = self.vllm_config.pad_for_cudagraph(
+                and num_spec_decode_tokens <= self.decode_cudagraph_max_bs):
+            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                 m.num_actual_tokens)
-            batch_size = num_total_tokens // (self.num_spec + 1)
+            batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)

             self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                 spec_state_indices_tensor, non_blocking=True)
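
Worked numbers for the padding above (all sizes assumed, including how pad_for_cudagraph rounds): with max_num_seqs = 4, num_spec = 1 and max_capture_size = 16, decode_cudagraph_max_bs = min(4 * 2, 16) = 8; if 5 spec-decode tokens are in flight and pad_for_cudagraph rounds them up to a captured size of 8:

    decode_cudagraph_max_bs = min(4 * (1 + 1), 16)            # 8
    num_actual_tokens = 8                                     # padded up from 5
    batch_size = min(decode_cudagraph_max_bs, num_actual_tokens)
    assert batch_size == 8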
@@ -229,7 +237,7 @@ class GDNAttentionMetadataBuilder(
             assert spec_token_masks is not None
             self.spec_token_masks[:spec_token_masks.size(0)].copy_(
                 spec_token_masks, non_blocking=True)
-            spec_token_masks = self.spec_token_masks[:m.num_actual_tokens]
+            spec_token_masks = self.spec_token_masks[:num_actual_tokens]
             spec_token_masks[spec_token_masks.size(0):].fill_(False)

             self.spec_query_start_loc[:num_spec_decodes + 1].copy_(
@@ -248,9 +256,9 @@ class GDNAttentionMetadataBuilder(
         if (self.use_full_cuda_graph and num_prefills == 0
                 and num_spec_decodes == 0
                 and num_decodes <= self.decode_cudagraph_max_bs):
-            num_total_tokens = self.vllm_config.pad_for_cudagraph(
+            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                 m.num_actual_tokens)
-            batch_size = num_total_tokens
+            batch_size = num_actual_tokens

             self.non_spec_state_indices_tensor[:num_decodes].copy_(
                 non_spec_state_indices_tensor, non_blocking=True)
@@ -274,6 +282,7 @@ class GDNAttentionMetadataBuilder(
             num_decode_tokens=num_decode_tokens,
             num_spec_decodes=num_spec_decodes,
             num_spec_decode_tokens=num_spec_decode_tokens,
+            num_actual_tokens=num_actual_tokens,
             has_initial_state=has_initial_state,
             spec_query_start_loc=spec_query_start_loc,
             non_spec_query_start_loc=non_spec_query_start_loc,