diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
index 010fcdda156c..5e5011fa2ac5 100644
--- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -41,6 +41,7 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     stride_istate_seq: tl.constexpr,
     stride_istate_dim: tl.constexpr,
     stride_istate_token: tl.constexpr,
+    stride_cache_indices: tl.constexpr,
     stride_o_seq: tl.constexpr,
     stride_o_dim: tl.constexpr,
     stride_o_token: tl.constexpr,
@@ -69,7 +70,7 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     # rather than mixing sequences - to make updating initial_states across sequences efficiently

     # single-sequence id
-    idx_seq = tl.load(batch_ptr + tl.program_id(0))
+    idx_seq = tl.load(batch_ptr + tl.program_id(0)).to(tl.int64)
     chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))

     # BLOCK_N elements along the feature-dimension (channel)
@@ -91,8 +92,9 @@ def _causal_conv1d_fwd_kernel(  # continuous batching

     if IS_CONTINUOUS_BATCHING:
         # cache_idx
-        conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(
-            tl.int64)
+        conv_state_batch_coord = tl.load(conv_state_indices_ptr +
+                                         idx_seq * stride_cache_indices).to(
+                                             tl.int64)
     else:
         # cache_idx
         conv_state_batch_coord = idx_seq
@@ -480,6 +482,8 @@ def causal_conv1d_fn(
     stride_o_seq = out.stride(0)
     stride_o_dim = out.stride(1)
     stride_o_token = out.stride(2)
+    stride_cache_indices = cache_indices.stride(
+        0) if cache_indices is not None else 0

     if validate_data:
         assert x.dim() == 2
@@ -595,6 +599,7 @@ def causal_conv1d_fn(
         stride_istate_seq,
         stride_istate_dim,
         stride_istate_token,
+        stride_cache_indices,
         stride_o_seq,
         stride_o_dim,
         stride_o_token,
diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index 843958bc79de..11f165d6cfc6 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -125,7 +125,7 @@ class GDNAttentionMetadataBuilder(
         common_prefix_len: int,
         common_attn_metadata: CommonAttentionMetadata,
         num_accepted_tokens: Optional[torch.Tensor] = None,
-        num_draft_tokens: Optional[torch.Tensor] = None,
+        num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
         fast_build: bool = False,
     ) -> GDNAttentionMetadata:
         m = common_attn_metadata
@@ -133,23 +133,25 @@ class GDNAttentionMetadataBuilder(
         query_start_loc = m.query_start_loc
         context_lens = m.num_computed_tokens_cpu
         context_lens_tensor = context_lens.to(query_start_loc.device)
-        seq_lens_tensor = m.seq_lens
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

-        if (not self.use_spec_decode or num_draft_tokens is None
-                or num_draft_tokens.sum().item() == 0):
+        if (not self.use_spec_decode or num_decode_draft_tokens_cpu is None
+                or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >=
+                                               0].sum().item() == 0):
             spec_sequence_masks = None
+            num_spec_decodes = 0
         else:
-            spec_sequence_masks = (num_draft_tokens > 0) & (
-                context_lens_tensor +
-                (num_draft_tokens + 1) == seq_lens_tensor)
-            if spec_sequence_masks.sum().item() == 0:
+            spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
+            num_spec_decodes = spec_sequence_masks.sum().item()
+            if num_spec_decodes == 0:
                 spec_sequence_masks = None
+            else:
+                spec_sequence_masks = spec_sequence_masks.to(
+                    query_start_loc.device, non_blocking=True)

         if spec_sequence_masks is None:
             num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
                 split_decodes_and_prefills(m, decode_threshold=1))
-            num_spec_decodes = 0
             num_spec_decode_tokens = 0
             spec_token_masks = None
             spec_state_indices_tensor = None
@@ -158,7 +160,6 @@ class GDNAttentionMetadataBuilder(
             non_spec_query_start_loc = query_start_loc
             num_accepted_tokens = None
         else:
-            num_spec_decodes = spec_sequence_masks.sum().item()
             query_lens = query_start_loc[1:] - query_start_loc[:-1]

             non_spec_query_lens = query_lens[~spec_sequence_masks]
@@ -314,28 +315,18 @@ class GDNAttentionMetadataBuilder(
         """
         m = common_attn_metadata

-        assert (m.num_reqs * (self.num_spec + 1) <= m.num_actual_tokens
-                and ((m.num_reqs + 1) * (self.num_spec + 1)
-                     >= m.num_actual_tokens)), \
-            "GDN only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
+        assert (
+            m.num_reqs <= self.decode_cudagraph_max_bs
+            and m.num_actual_tokens <= self.decode_cudagraph_max_bs), (
+                f"GDN only supports decode-only full CUDAGraph capture. "
+                f"Make sure batch size ({m.num_reqs}) <= "
+                f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
+                f"and number of tokens ({m.num_actual_tokens}) <= "
+                f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}).")

-        num_accepted_tokens = torch.full((m.num_reqs, ),
-                                         m.max_query_len,
-                                         dtype=torch.int32,
-                                         device=m.query_start_loc.device)
-        num_drafted_tokens = torch.full((m.num_reqs, ),
-                                        self.num_spec,
-                                        dtype=torch.int32,
-                                        device=m.query_start_loc.device)
+        num_accepted_tokens = torch.diff(m.query_start_loc)
+        num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
+        m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()

-        # Fixes query-start loc for spec-sequence-indices.
-        m.query_start_loc = torch.arange(0,
-                                         m.num_actual_tokens + 1,
-                                         step=m.max_query_len,
-                                         device=m.query_start_loc.device,
-                                         dtype=torch.int32)
-        m.num_computed_tokens_cpu = (m.seq_lens_cpu - torch.full(
-            (m.num_reqs, ), m.max_query_len, dtype=torch.int32, device='cpu'))
-
-        return self.build(0, m, num_accepted_tokens, num_drafted_tokens)
+        return self.build(0, m, num_accepted_tokens,
+                          num_decode_draft_tokens_cpu)
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index a1969463cbfb..cbf439aa697b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -360,8 +360,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                                          dtype=torch.int64)
         self.num_discarded_requests = 0

-        self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
-                                                  dtype=torch.int32)
+        self.num_decode_draft_tokens = self._make_buffer(self.max_num_reqs,
+                                                         dtype=torch.int32)
         self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
                                                      dtype=torch.int64)

@@ -1103,17 +1103,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # Iterate over the dictionary rather than all requests since not all
             # requests have draft tokens.
             num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
+            # For chunked prefills, use -1 as mask rather than 0, as guided
+            # decoding may rollback speculative tokens.
+            num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)
             for req_id, draft_token_ids in (
                     scheduler_output.scheduled_spec_decode_tokens.items()):
                 req_idx = self.input_batch.req_id_to_index[req_id]
                 num_draft_tokens[req_idx] = len(draft_token_ids)
-
+                num_decode_draft_tokens[req_idx] = (len(draft_token_ids) if (
+                    self.input_batch.num_computed_tokens_cpu[req_idx]
+                    >= self.input_batch.num_prompt_tokens[req_idx]) else -1)
             spec_decode_metadata = self._calc_spec_decode_metadata(
                 num_draft_tokens, cu_num_tokens)
             logits_indices = spec_decode_metadata.logits_indices
-            self.num_draft_tokens.np[:num_reqs] = num_draft_tokens
-            self.num_draft_tokens.np[num_reqs:].fill(0)
-            self.num_draft_tokens.copy_to_gpu()
+
+            # For DECODE only cuda graph of some attention backends (e.g., GDN).
+            self.num_decode_draft_tokens.np[:
+                                            num_reqs] = num_decode_draft_tokens
+            self.num_decode_draft_tokens.np[num_reqs:].fill(-1)
+            self.num_decode_draft_tokens.copy_to_gpu()

         logits_indices_padded = None
         if self.cache_config.kv_sharing_fast_prefill:
@@ -1217,7 +1225,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 extra_attn_metadata_args = dict(
                     num_accepted_tokens=self.num_accepted_tokens.
                     gpu[:num_reqs],
-                    num_draft_tokens=self.num_draft_tokens.gpu[:num_reqs],
+                    num_decode_draft_tokens_cpu=self.
+                    num_decode_draft_tokens.cpu[:num_reqs],
                )

            if ubatch_slices is not None:
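
For context on the `stride_cache_indices` change above: when `cache_indices` is a non-contiguous view (for example a column slice of a larger index table), its element stride along dim 0 is greater than 1, so a kernel that offsets the raw pointer by `idx_seq` alone reads the wrong entry. A minimal PyTorch sketch of the effect; the `table` tensor and `idx_seq` value here are hypothetical illustrations, not from the PR:

```python
import torch

# Hypothetical index table; the change only assumes cache_indices may be any
# 1-D integer view, contiguous or not.
table = torch.arange(12, dtype=torch.int32).reshape(4, 3)
cache_indices = table[:, 0]  # non-contiguous view: values 0, 3, 6, 9

print(cache_indices.is_contiguous())  # False
print(cache_indices.stride(0))        # 3

# Offsetting the underlying storage by idx_seq alone treats the view as if it
# were contiguous and picks up a neighbor from another column:
flat = table.flatten()
idx_seq = 1
print(flat[idx_seq].item())                             # 1  (wrong element)

# Scaling the offset by stride(0), as the updated kernel does, recovers
# cache_indices[idx_seq]:
print(flat[idx_seq * cache_indices.stride(0)].item())   # 3  (correct)
```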