[Qwen3-Next][GDN] Fixes a CUDA graph capturing bug in GDN metadata and a stride bug in causal_conv1d (#25743)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Tao He 2025-09-26 16:18:58 +08:00 committed by GitHub
parent 6e30010d2f
commit 99b3a504c5
3 changed files with 48 additions and 43 deletions


@@ -41,6 +41,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
     stride_istate_seq: tl.constexpr,
     stride_istate_dim: tl.constexpr,
     stride_istate_token: tl.constexpr,
+    stride_cache_indices: tl.constexpr,
     stride_o_seq: tl.constexpr,
     stride_o_dim: tl.constexpr,
     stride_o_token: tl.constexpr,
@@ -69,7 +70,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
     # rather than mixing sequences - to make updating initial_states across sequences efficiently
     # single-sequence id
-    idx_seq = tl.load(batch_ptr + tl.program_id(0))
+    idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64)
     chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))

     # BLOCK_N elements along the feature-dimension (channel)
@@ -91,7 +92,8 @@ def _causal_conv1d_fwd_kernel( # continuous batching
     if IS_CONTINUOUS_BATCHING:
         # cache_idx
-        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
-            tl.int64)
+        conv_state_batch_coord = tl.load(conv_state_indices_ptr +
+                                         idx_seq * stride_cache_indices).to(
+                                             tl.int64)
     else:
         # cache_idx
@@ -480,6 +482,8 @@ def causal_conv1d_fn(
     stride_o_seq = out.stride(0)
     stride_o_dim = out.stride(1)
     stride_o_token = out.stride(2)
+    stride_cache_indices = cache_indices.stride(
+        0) if cache_indices is not None else 0

     if validate_data:
         assert x.dim() == 2
@@ -595,6 +599,7 @@ def causal_conv1d_fn(
         stride_istate_seq,
         stride_istate_dim,
         stride_istate_token,
+        stride_cache_indices,
         stride_o_seq,
         stride_o_dim,
         stride_o_token,
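The kernel change above can be illustrated outside Triton. Below is a minimal sketch, not part of the commit, of the stride issue it addresses (the tensors here are illustrative stand-ins): when `cache_indices` is a non-contiguous view, raw pointer arithmetic of the form `ptr + idx_seq` lands on the wrong element, so the element offset has to be scaled by `cache_indices.stride(0)`, which is what the new `stride_cache_indices` argument carries into `_causal_conv1d_fwd_kernel`. The added `.to(tl.int64)` cast on `idx_seq` additionally keeps the computed pointer offsets in 64-bit arithmetic.

```python
import torch

# Illustrative stand-in for cache_indices: slicing a column out of a 2-D
# buffer yields a non-contiguous view whose stride(0) is the parent row
# length, not 1.
parent = torch.arange(12, dtype=torch.int32).reshape(4, 3)
cache_indices = parent[:, 0]          # values [0, 3, 6, 9], stride(0) == 3

flat = parent.flatten()               # what raw pointer arithmetic "sees"
idx_seq = 2

wrong = flat[idx_seq]                              # ptr + idx_seq          -> 2
right = flat[idx_seq * cache_indices.stride(0)]    # ptr + idx_seq * stride -> 6

assert right.item() == cache_indices[idx_seq].item()
print(wrong.item(), right.item())     # 2 6
```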


@@ -125,7 +125,7 @@ class GDNAttentionMetadataBuilder(
         common_prefix_len: int,
         common_attn_metadata: CommonAttentionMetadata,
         num_accepted_tokens: Optional[torch.Tensor] = None,
-        num_draft_tokens: Optional[torch.Tensor] = None,
+        num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
        fast_build: bool = False,
     ) -> GDNAttentionMetadata:
         m = common_attn_metadata
@@ -133,23 +133,25 @@ class GDNAttentionMetadataBuilder(
         query_start_loc = m.query_start_loc
         context_lens = m.num_computed_tokens_cpu
         context_lens_tensor = context_lens.to(query_start_loc.device)
-        seq_lens_tensor = m.seq_lens
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

-        if (not self.use_spec_decode or num_draft_tokens is None
-                or num_draft_tokens.sum().item() == 0):
+        if (not self.use_spec_decode or num_decode_draft_tokens_cpu is None
+                or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >=
+                                               0].sum().item() == 0):
             spec_sequence_masks = None
+            num_spec_decodes = 0
         else:
-            spec_sequence_masks = (num_draft_tokens > 0) & (
-                context_lens_tensor +
-                (num_draft_tokens + 1) == seq_lens_tensor)
-            if spec_sequence_masks.sum().item() == 0:
-                spec_sequence_masks = None
+            spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
+            num_spec_decodes = spec_sequence_masks.sum().item()
+            if num_spec_decodes == 0:
+                spec_sequence_masks = None
+            else:
+                spec_sequence_masks = spec_sequence_masks.to(
+                    query_start_loc.device, non_blocking=True)

         if spec_sequence_masks is None:
             num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
                 split_decodes_and_prefills(m, decode_threshold=1))
-            num_spec_decodes = 0
             num_spec_decode_tokens = 0
             spec_token_masks = None
             spec_state_indices_tensor = None
@@ -158,7 +160,6 @@ class GDNAttentionMetadataBuilder(
             non_spec_query_start_loc = query_start_loc
             num_accepted_tokens = None
         else:
-            num_spec_decodes = spec_sequence_masks.sum().item()
             query_lens = query_start_loc[1:] - query_start_loc[:-1]
             non_spec_query_lens = query_lens[~spec_sequence_masks]
@@ -314,28 +315,18 @@ class GDNAttentionMetadataBuilder(
         """
         m = common_attn_metadata

-        assert (m.num_reqs * (self.num_spec + 1) <= m.num_actual_tokens
-                and ((m.num_reqs + 1) * (self.num_spec + 1)
-                     >= m.num_actual_tokens)), \
-            "GDN only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
+        assert (
+            m.num_reqs <= self.decode_cudagraph_max_bs
+            and m.num_actual_tokens <= self.decode_cudagraph_max_bs), (
+                f"GDN only supports decode-only full CUDAGraph capture. "
+                f"Make sure batch size ({m.num_reqs}) <= "
+                f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
+                f"and number of tokens ({m.num_actual_tokens}) <= "
+                f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}).")

-        num_accepted_tokens = torch.full((m.num_reqs, ),
-                                         m.max_query_len,
-                                         dtype=torch.int32,
-                                         device=m.query_start_loc.device)
-        num_drafted_tokens = torch.full((m.num_reqs, ),
-                                        self.num_spec,
-                                        dtype=torch.int32,
-                                        device=m.query_start_loc.device)
-
-        # Fixes query-start loc for spec-sequence-indices.
-        m.query_start_loc = torch.arange(0,
-                                         m.num_actual_tokens + 1,
-                                         step=m.max_query_len,
-                                         device=m.query_start_loc.device,
-                                         dtype=torch.int32)
-        m.num_computed_tokens_cpu = (m.seq_lens_cpu - torch.full(
-            (m.num_reqs, ), m.max_query_len, dtype=torch.int32, device='cpu'))
+        num_accepted_tokens = torch.diff(m.query_start_loc)
+        num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
+        m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()

-        return self.build(0, m, num_accepted_tokens, num_drafted_tokens)
+        return self.build(0, m, num_accepted_tokens,
+                          num_decode_draft_tokens_cpu)
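A hedged sketch of the masking convention the builder now relies on (the tensor values below are made up; only the control flow mirrors the diff above): `num_decode_draft_tokens_cpu` holds -1 for requests that are not pure decodes, so a request counts as a spec-decode sequence iff its entry is >= 0, and the spec path is skipped entirely when the masked sum of draft tokens is zero. This replaces the old length arithmetic against `seq_lens`.

```python
import torch

# Illustrative per-request draft counts on CPU: -1 masks out requests that
# are not pure decodes (e.g. still in chunked prefill); >= 0 means a decode
# request with that many scheduled draft tokens.
num_decode_draft_tokens_cpu = torch.tensor([-1, 2, 0, 3], dtype=torch.int32)

use_spec_decode = True
if (not use_spec_decode or num_decode_draft_tokens_cpu is None
        or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu
                                       >= 0].sum().item() == 0):
    spec_sequence_masks = None
    num_spec_decodes = 0
else:
    spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
    num_spec_decodes = spec_sequence_masks.sum().item()

print(spec_sequence_masks.tolist(), num_spec_decodes)
# [False, True, True, True] 3
```

During decode-only CUDA graph capture, every request in the dummy batch looks like a full decode, so `build_for_cudagraph_capture` now derives `num_accepted_tokens = torch.diff(m.query_start_loc)` and sets `num_decode_draft_tokens_cpu = num_accepted_tokens - 1`, keeping all entries >= 0 so that the captured graph exercises the spec-decode path it will later replay.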


@@ -360,7 +360,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                                          dtype=torch.int64)
         self.num_discarded_requests = 0
-        self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
-                                                  dtype=torch.int32)
+        self.num_decode_draft_tokens = self._make_buffer(self.max_num_reqs,
+                                                         dtype=torch.int32)
         self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
                                                      dtype=torch.int64)
@@ -1103,17 +1103,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # Iterate over the dictionary rather than all requests since not all
             # requests have draft tokens.
             num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
+            # For chunked prefills, use -1 as mask rather than 0, as guided
+            # decoding may rollback speculative tokens.
+            num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
             for req_id, draft_token_ids in (
                     scheduler_output.scheduled_spec_decode_tokens.items()):
                 req_idx = self.input_batch.req_id_to_index[req_id]
                 num_draft_tokens[req_idx] = len(draft_token_ids)
+                num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if (
+                    self.input_batch.num_computed_tokens_cpu[req_idx]
+                    >= self.input_batch.num_prompt_tokens[req_idx]) else -1)

             spec_decode_metadata = self._calc_spec_decode_metadata(
                 num_draft_tokens, cu_num_tokens)
             logits_indices = spec_decode_metadata.logits_indices
-            self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
-            self.num_draft_tokens.np[num_reqs:].fill(0)
-            self.num_draft_tokens.copy_to_gpu()
+            # For DECODE only cuda graph of some attention backends (e.g., GDN).
+            self.num_decode_draft_tokens.np[:
+                                            num_reqs] = num_decode_draft_tokens
+            self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
+            self.num_decode_draft_tokens.copy_to_gpu()

         logits_indices_padded = None
         if self.cache_config.kv_sharing_fast_prefill:
@@ -1217,7 +1225,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     extra_attn_metadata_args = dict(
                         num_accepted_tokens=self.num_accepted_tokens.
                         gpu[:num_reqs],
-                        num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
+                        num_decode_draft_tokens_cpu=self.
+                        num_decode_draft_tokens.cpu[:num_reqs],
                     )

         if ubatch_slices is not None:
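For the runner-side bookkeeping above, here is a small illustration (the request state is hypothetical, not the runner's real data structures) of how `num_decode_draft_tokens` differs from `num_draft_tokens`: entries default to -1, and only requests whose prompt is fully computed record their scheduled draft count, so requests still in chunked prefill stay masked out of the GDN decode-only CUDA graph path.

```python
import numpy as np

num_reqs = 3
# Hypothetical per-request state mirroring input_batch fields.
num_computed_tokens = np.array([8, 3, 10])   # tokens already computed
num_prompt_tokens = np.array([8, 6, 10])     # prompt lengths
scheduled_drafts = {0: 2, 1: 2, 2: 0}        # req_idx -> len(draft_token_ids)

num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
for req_idx, n_draft in scheduled_drafts.items():
    num_draft_tokens[req_idx] = n_draft
    # Only pure decodes (prompt fully computed) keep an entry >= 0;
    # requests still in chunked prefill remain -1.
    if num_computed_tokens[req_idx] >= num_prompt_tokens[req_idx]:
        num_decode_draft_tokens[req_idx] = n_draft

print(num_draft_tokens.tolist())         # [2, 2, 0]
print(num_decode_draft_tokens.tolist())  # [2, -1, 0]
```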