[Qwen3-Next][GDN] fixes cuda graph capturing bug in GDN metadata and a stride bug in causal_conv_1d. (#25743)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
parent 6e30010d2f
commit 99b3a504c5
@@ -41,6 +41,7 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     stride_istate_seq: tl.constexpr,
     stride_istate_dim: tl.constexpr,
     stride_istate_token: tl.constexpr,
+    stride_cache_indices: tl.constexpr,
     stride_o_seq: tl.constexpr,
     stride_o_dim: tl.constexpr,
     stride_o_token: tl.constexpr,
@@ -69,7 +70,7 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     # rather than mixing sequences - to make updating initial_states across sequences efficiently
 
     # single-sequence id
-    idx_seq = tl.load(batch_ptr + tl.program_id(0))
+    idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64)
     chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
 
     # BLOCK_N elements along the feature-dimension (channel)
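The `.to(tl.int64)` cast matters because `idx_seq` is later multiplied by per-sequence strides to form pointer offsets, and a 32-bit index can wrap once caches get large. A minimal sketch of the failure mode outside Triton, with made-up sizes that are not taken from this commit:

    import torch

    # Hypothetical sizes, chosen only to push the offset past 2**31 - 1.
    idx_seq = torch.tensor([70_000], dtype=torch.int32)
    stride_istate_seq = 40_000

    # int32 arithmetic wraps silently; inside the kernel this would become a
    # bogus pointer offset.
    print(idx_seq * stride_istate_seq)                  # tensor([-1494967296], dtype=torch.int32)
    # Promoting the index first keeps the offset exact.
    print(idx_seq.to(torch.int64) * stride_istate_seq)  # tensor([2800000000])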
@@ -91,8 +92,9 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
 
     if IS_CONTINUOUS_BATCHING:
         # cache_idx
-        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
-            tl.int64)
+        conv_state_batch_coord = tl.load(conv_state_indices_ptr +
+                                         idx_seq * stride_cache_indices).to(
+                                             tl.int64)
     else:
         # cache_idx
         conv_state_batch_coord = idx_seq
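This is the stride bug itself: the old kernel indexed `conv_state_indices_ptr` with `idx_seq` alone, which is only correct when `cache_indices` is contiguous with unit stride. If the caller hands the kernel a strided view, every lookup past the first lands on the wrong element, which is why the wrapper below now forwards `cache_indices.stride(0)`. A small sketch, with a hypothetical layout, of when that unit-stride assumption breaks:

    import torch

    # Hypothetical layout: per-sequence cache slots stored as one column of a
    # larger (num_seqs, num_levels) index tensor.
    state_indices = torch.arange(12, dtype=torch.int32).reshape(4, 3)
    cache_indices = state_indices[:, 0]    # a view, not a copy

    print(cache_indices.stride(0))         # 3, not 1
    # Old kernel logic (flat offset = idx_seq) would read 0, 1, 2, ...
    # Scaling by the stride reads 0, 3, 6, ... as intended.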
@@ -480,6 +482,8 @@ def causal_conv1d_fn(
     stride_o_seq = out.stride(0)
     stride_o_dim = out.stride(1)
     stride_o_token = out.stride(2)
+    stride_cache_indices = cache_indices.stride(
+        0) if cache_indices is not None else 0
 
     if validate_data:
         assert x.dim() == 2
@@ -595,6 +599,7 @@ def causal_conv1d_fn(
         stride_istate_seq,
         stride_istate_dim,
         stride_istate_token,
+        stride_cache_indices,
         stride_o_seq,
         stride_o_dim,
         stride_o_token,
@@ -125,7 +125,7 @@ class GDNAttentionMetadataBuilder(
         common_prefix_len: int,
         common_attn_metadata: CommonAttentionMetadata,
         num_accepted_tokens: Optional[torch.Tensor] = None,
-        num_draft_tokens: Optional[torch.Tensor] = None,
+        num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
         fast_build: bool = False,
     ) -> GDNAttentionMetadata:
         m = common_attn_metadata
@@ -133,23 +133,25 @@ class GDNAttentionMetadataBuilder(
         query_start_loc = m.query_start_loc
         context_lens = m.num_computed_tokens_cpu
         context_lens_tensor = context_lens.to(query_start_loc.device)
-        seq_lens_tensor = m.seq_lens
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
 
-        if (not self.use_spec_decode or num_draft_tokens is None
-                or num_draft_tokens.sum().item() == 0):
+        if (not self.use_spec_decode or num_decode_draft_tokens_cpu is None
+                or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >=
+                                               0].sum().item() == 0):
             spec_sequence_masks = None
+            num_spec_decodes = 0
         else:
-            spec_sequence_masks = (num_draft_tokens > 0) & (
-                context_lens_tensor +
-                (num_draft_tokens + 1) == seq_lens_tensor)
-            if spec_sequence_masks.sum().item() == 0:
+            spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
+            num_spec_decodes = spec_sequence_masks.sum().item()
+            if num_spec_decodes == 0:
                 spec_sequence_masks = None
+            else:
+                spec_sequence_masks = spec_sequence_masks.to(
+                    query_start_loc.device, non_blocking=True)
 
         if spec_sequence_masks is None:
             num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
                 split_decodes_and_prefills(m, decode_threshold=1))
-            num_spec_decodes = 0
             num_spec_decode_tokens = 0
             spec_token_masks = None
             spec_state_indices_tensor = None
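On the metadata-builder side, the interface change replaces the GPU-resident `num_draft_tokens` (and the arithmetic against `seq_lens_tensor`) with a CPU tensor in which `-1` means "not a decode-time spec sequence". A request is spec-decoded exactly when its entry is non-negative, and the mask is computed on the CPU before being moved over, which keeps the build path friendly to CUDA graph capture. A small sketch of the new masking, using made-up per-request values:

    import torch

    # Hypothetical per-request counts: -1 masks requests that are not pure
    # decodes (e.g. still prefilling); >= 0 gives the draft-token count.
    num_decode_draft_tokens_cpu = torch.tensor([2, -1, 0, 3])

    spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
    num_spec_decodes = int(spec_sequence_masks.sum().item())
    total_drafts = num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0].sum().item()

    print(spec_sequence_masks.tolist())    # [True, False, True, True]
    print(num_spec_decodes, total_drafts)  # 3 5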
@@ -158,7 +160,6 @@ class GDNAttentionMetadataBuilder(
             non_spec_query_start_loc = query_start_loc
             num_accepted_tokens = None
         else:
-            num_spec_decodes = spec_sequence_masks.sum().item()
             query_lens = query_start_loc[1:] - query_start_loc[:-1]
 
             non_spec_query_lens = query_lens[~spec_sequence_masks]
@@ -314,28 +315,18 @@ class GDNAttentionMetadataBuilder(
         """
         m = common_attn_metadata
 
-        assert (m.num_reqs * (self.num_spec + 1) <= m.num_actual_tokens
-                and ((m.num_reqs + 1) * (self.num_spec + 1)
-                     >= m.num_actual_tokens)), \
-            "GDN only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
+        assert (
+            m.num_reqs <= self.decode_cudagraph_max_bs
+            and m.num_actual_tokens <= self.decode_cudagraph_max_bs), (
+                f"GDN only supports decode-only full CUDAGraph capture. "
+                f"Make sure batch size ({m.num_reqs}) <= "
+                f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
+                f"and number of tokens ({m.num_actual_tokens}) <= "
+                f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}).")
 
-        num_accepted_tokens = torch.full((m.num_reqs, ),
-                                         m.max_query_len,
-                                         dtype=torch.int32,
-                                         device=m.query_start_loc.device)
-        num_drafted_tokens = torch.full((m.num_reqs, ),
-                                        self.num_spec,
-                                        dtype=torch.int32,
-                                        device=m.query_start_loc.device)
+        num_accepted_tokens = torch.diff(m.query_start_loc)
+        num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
+        m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
 
-        # Fixes query-start loc for spec-sequence-indices.
-        m.query_start_loc = torch.arange(0,
-                                         m.num_actual_tokens + 1,
-                                         step=m.max_query_len,
-                                         device=m.query_start_loc.device,
-                                         dtype=torch.int32)
-        m.num_computed_tokens_cpu = (m.seq_lens_cpu - torch.full(
-            (m.num_reqs, ), m.max_query_len, dtype=torch.int32, device='cpu'))
-
-        return self.build(0, m, num_accepted_tokens, num_drafted_tokens)
+        return self.build(0, m, num_accepted_tokens,
+                          num_decode_draft_tokens_cpu)
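The CUDA graph capturing bug appears to have come from the old capture path fabricating uniform `num_accepted_tokens`/`num_drafted_tokens` tensors and overwriting `m.query_start_loc`, which could disagree with the layout of the dummy capture batch. The new code derives everything from the batch's own `query_start_loc` and tightens the assertion to the decode-only capture limits. A sketch of the derivation under an assumed dummy batch (3 requests, 3 tokens each; values are illustrative, not from the commit):

    import torch

    # Hypothetical dummy capture batch.
    query_start_loc = torch.tensor([0, 3, 6, 9], dtype=torch.int32)
    seq_lens_cpu = torch.tensor([10, 20, 30], dtype=torch.int32)

    num_accepted_tokens = torch.diff(query_start_loc)              # [3, 3, 3]
    num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()  # [2, 2, 2]
    num_computed_tokens_cpu = seq_lens_cpu - num_accepted_tokens.cpu()

    print(num_decode_draft_tokens_cpu.tolist(), num_computed_tokens_cpu.tolist())
    # [2, 2, 2] [7, 17, 27]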
@@ -360,8 +360,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                                 dtype=torch.int64)
         self.num_discarded_requests = 0
 
-        self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
+        self.num_decode_draft_tokens = self._make_buffer(self.max_num_reqs,
                                                   dtype=torch.int32)
         self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
                                                      dtype=torch.int64)
 
@@ -1103,17 +1103,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # Iterate over the dictionary rather than all requests since not all
             # requests have draft tokens.
             num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
+            # For chunked prefills, use -1 as mask rather than 0, as guided
+            # decoding may rollback speculative tokens.
+            num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
             for req_id, draft_token_ids in (
                     scheduler_output.scheduled_spec_decode_tokens.items()):
                 req_idx = self.input_batch.req_id_to_index[req_id]
                 num_draft_tokens[req_idx] = len(draft_token_ids)
+                num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if (
+                    self.input_batch.num_computed_tokens_cpu[req_idx]
+                    >= self.input_batch.num_prompt_tokens[req_idx]) else -1)
             spec_decode_metadata = self._calc_spec_decode_metadata(
                 num_draft_tokens, cu_num_tokens)
             logits_indices = spec_decode_metadata.logits_indices
-            self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
-            self.num_draft_tokens.np[num_reqs:].fill(0)
-            self.num_draft_tokens.copy_to_gpu()
+
+            # For DECODE only cuda graph of some attention backends (e.g., GDN).
+            self.num_decode_draft_tokens.np[:
+                                            num_reqs] = num_decode_draft_tokens
+            self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
+            self.num_decode_draft_tokens.copy_to_gpu()
 
             logits_indices_padded = None
             if self.cache_config.kv_sharing_fast_prefill:
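In the model runner, a request only counts toward the decode-only spec path once its whole prompt has been computed; partially prefilled (chunked) requests keep the `-1` mask so they cannot be mistaken for decodes even when they carry scheduled draft tokens. A sketch of that bookkeeping with assumed values (the variable names are illustrative simplifications of the runner's fields):

    import numpy as np

    num_reqs = 3
    num_prompt_tokens = np.array([16, 16, 16])
    num_computed_tokens = np.array([16, 8, 16])   # request 1 is mid-prefill
    scheduled_drafts = {0: 2, 1: 2, 2: 0}         # draft tokens per request

    num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
    for req_idx, n_draft in scheduled_drafts.items():
        if num_computed_tokens[req_idx] >= num_prompt_tokens[req_idx]:
            num_decode_draft_tokens[req_idx] = n_draft

    print(num_decode_draft_tokens.tolist())       # [2, -1, 0]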
@@ -1217,7 +1225,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 extra_attn_metadata_args = dict(
                     num_accepted_tokens=self.num_accepted_tokens.
                     gpu[:num_reqs],
-                    num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
+                    num_decode_draft_tokens_cpu=self.
+                    num_decode_draft_tokens.cpu[:num_reqs],
                 )
 
             if ubatch_slices is not None: