mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 23:45:54 +08:00
[Qwen3-Next][GDN] fixes cuda graph capturing bug in GDN metadata and a stride bug in causal_conv_1d. (#25743)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
This commit is contained in:
parent
6e30010d2f
commit
99b3a504c5
@ -41,6 +41,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
stride_istate_seq: tl.constexpr,
|
||||
stride_istate_dim: tl.constexpr,
|
||||
stride_istate_token: tl.constexpr,
|
||||
stride_cache_indices: tl.constexpr,
|
||||
stride_o_seq: tl.constexpr,
|
||||
stride_o_dim: tl.constexpr,
|
||||
stride_o_token: tl.constexpr,
|
||||
@ -69,7 +70,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
# rather than mixing sequences - to make updating initial_states across sequences efficiently
|
||||
|
||||
# single-sequence id
|
||||
idx_seq = tl.load(batch_ptr + tl.program_id(0))
|
||||
idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64)
|
||||
chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
|
||||
|
||||
# BLOCK_N elements along the feature-dimension (channel)
|
||||
@ -91,7 +92,8 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
# cache_idx
|
||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
|
||||
conv_state_batch_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_cache_indices).to(
|
||||
tl.int64)
|
||||
else:
|
||||
# cache_idx
|
||||
@ -480,6 +482,8 @@ def causal_conv1d_fn(
|
||||
stride_o_seq = out.stride(0)
|
||||
stride_o_dim = out.stride(1)
|
||||
stride_o_token = out.stride(2)
|
||||
stride_cache_indices = cache_indices.stride(
|
||||
0) if cache_indices is not None else 0
|
||||
|
||||
if validate_data:
|
||||
assert x.dim() == 2
|
||||
@ -595,6 +599,7 @@ def causal_conv1d_fn(
|
||||
stride_istate_seq,
|
||||
stride_istate_dim,
|
||||
stride_istate_token,
|
||||
stride_cache_indices,
|
||||
stride_o_seq,
|
||||
stride_o_dim,
|
||||
stride_o_token,
|
||||
|
||||
@ -125,7 +125,7 @@ class GDNAttentionMetadataBuilder(
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
num_draft_tokens: Optional[torch.Tensor] = None,
|
||||
num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
|
||||
fast_build: bool = False,
|
||||
) -> GDNAttentionMetadata:
|
||||
m = common_attn_metadata
|
||||
@ -133,23 +133,25 @@ class GDNAttentionMetadataBuilder(
|
||||
query_start_loc = m.query_start_loc
|
||||
context_lens = m.num_computed_tokens_cpu
|
||||
context_lens_tensor = context_lens.to(query_start_loc.device)
|
||||
seq_lens_tensor = m.seq_lens
|
||||
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
|
||||
|
||||
if (not self.use_spec_decode or num_draft_tokens is None
|
||||
or num_draft_tokens.sum().item() == 0):
|
||||
if (not self.use_spec_decode or num_decode_draft_tokens_cpu is None
|
||||
or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >=
|
||||
0].sum().item() == 0):
|
||||
spec_sequence_masks = None
|
||||
num_spec_decodes = 0
|
||||
else:
|
||||
spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
|
||||
num_spec_decodes = spec_sequence_masks.sum().item()
|
||||
if num_spec_decodes == 0:
|
||||
spec_sequence_masks = None
|
||||
else:
|
||||
spec_sequence_masks = (num_draft_tokens > 0) & (
|
||||
context_lens_tensor +
|
||||
(num_draft_tokens + 1) == seq_lens_tensor)
|
||||
if spec_sequence_masks.sum().item() == 0:
|
||||
spec_sequence_masks = None
|
||||
spec_sequence_masks = spec_sequence_masks.to(
|
||||
query_start_loc.device, non_blocking=True)
|
||||
|
||||
if spec_sequence_masks is None:
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(m, decode_threshold=1))
|
||||
num_spec_decodes = 0
|
||||
num_spec_decode_tokens = 0
|
||||
spec_token_masks = None
|
||||
spec_state_indices_tensor = None
|
||||
@ -158,7 +160,6 @@ class GDNAttentionMetadataBuilder(
|
||||
non_spec_query_start_loc = query_start_loc
|
||||
num_accepted_tokens = None
|
||||
else:
|
||||
num_spec_decodes = spec_sequence_masks.sum().item()
|
||||
query_lens = query_start_loc[1:] - query_start_loc[:-1]
|
||||
|
||||
non_spec_query_lens = query_lens[~spec_sequence_masks]
|
||||
@ -314,28 +315,18 @@ class GDNAttentionMetadataBuilder(
|
||||
"""
|
||||
m = common_attn_metadata
|
||||
|
||||
assert (m.num_reqs * (self.num_spec + 1) <= m.num_actual_tokens
|
||||
and ((m.num_reqs + 1) * (self.num_spec + 1)
|
||||
>= m.num_actual_tokens)), \
|
||||
"GDN only supports decode-only full CUDAGraph capture. " \
|
||||
"Make sure all cudagraph capture sizes <= max_num_seq."
|
||||
assert (
|
||||
m.num_reqs <= self.decode_cudagraph_max_bs
|
||||
and m.num_actual_tokens <= self.decode_cudagraph_max_bs), (
|
||||
f"GDN only supports decode-only full CUDAGraph capture. "
|
||||
f"Make sure batch size ({m.num_reqs}) <= "
|
||||
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
|
||||
f"and number of tokens ({m.num_actual_tokens}) <= "
|
||||
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}).")
|
||||
|
||||
num_accepted_tokens = torch.full((m.num_reqs, ),
|
||||
m.max_query_len,
|
||||
dtype=torch.int32,
|
||||
device=m.query_start_loc.device)
|
||||
num_drafted_tokens = torch.full((m.num_reqs, ),
|
||||
self.num_spec,
|
||||
dtype=torch.int32,
|
||||
device=m.query_start_loc.device)
|
||||
num_accepted_tokens = torch.diff(m.query_start_loc)
|
||||
num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
|
||||
m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
|
||||
|
||||
# Fixes query-start loc for spec-sequence-indices.
|
||||
m.query_start_loc = torch.arange(0,
|
||||
m.num_actual_tokens + 1,
|
||||
step=m.max_query_len,
|
||||
device=m.query_start_loc.device,
|
||||
dtype=torch.int32)
|
||||
m.num_computed_tokens_cpu = (m.seq_lens_cpu - torch.full(
|
||||
(m.num_reqs, ), m.max_query_len, dtype=torch.int32, device='cpu'))
|
||||
|
||||
return self.build(0, m, num_accepted_tokens, num_drafted_tokens)
|
||||
return self.build(0, m, num_accepted_tokens,
|
||||
num_decode_draft_tokens_cpu)
|
||||
|
||||
@ -360,7 +360,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
dtype=torch.int64)
|
||||
self.num_discarded_requests = 0
|
||||
|
||||
self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
|
||||
self.num_decode_draft_tokens = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int32)
|
||||
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int64)
|
||||
@ -1103,17 +1103,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# Iterate over the dictionary rather than all requests since not all
|
||||
# requests have draft tokens.
|
||||
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
|
||||
# For chunked prefills, use -1 as mask rather than 0, as guided
|
||||
# decoding may rollback speculative tokens.
|
||||
num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
|
||||
for req_id, draft_token_ids in (
|
||||
scheduler_output.scheduled_spec_decode_tokens.items()):
|
||||
req_idx = self.input_batch.req_id_to_index[req_id]
|
||||
num_draft_tokens[req_idx] = len(draft_token_ids)
|
||||
|
||||
num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if (
|
||||
self.input_batch.num_computed_tokens_cpu[req_idx]
|
||||
>= self.input_batch.num_prompt_tokens[req_idx]) else -1)
|
||||
spec_decode_metadata = self._calc_spec_decode_metadata(
|
||||
num_draft_tokens, cu_num_tokens)
|
||||
logits_indices = spec_decode_metadata.logits_indices
|
||||
self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
|
||||
self.num_draft_tokens.np[num_reqs:].fill(0)
|
||||
self.num_draft_tokens.copy_to_gpu()
|
||||
|
||||
# For DECODE only cuda graph of some attention backends (e.g., GDN).
|
||||
self.num_decode_draft_tokens.np[:
|
||||
num_reqs] = num_decode_draft_tokens
|
||||
self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
|
||||
self.num_decode_draft_tokens.copy_to_gpu()
|
||||
|
||||
logits_indices_padded = None
|
||||
if self.cache_config.kv_sharing_fast_prefill:
|
||||
@ -1217,7 +1225,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
extra_attn_metadata_args = dict(
|
||||
num_accepted_tokens=self.num_accepted_tokens.
|
||||
gpu[:num_reqs],
|
||||
num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
|
||||
num_decode_draft_tokens_cpu=self.
|
||||
num_decode_draft_tokens.cpu[:num_reqs],
|
||||
)
|
||||
|
||||
if ubatch_slices is not None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user