diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index 90e520e244416..0b63acf2dc5a5 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -252,7 +252,6 @@ class MambaMixer(MambaBase, CustomOp):
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
             has_initial_states_p = attn_metadata.has_initial_states_p
-            num_padded_decodes = attn_metadata.num_padded_decodes
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -281,7 +280,7 @@ class MambaMixer(MambaBase, CustomOp):
             state_indices_tensor,
             num_prefill_tokens,
             num_prefills,
-            num_padded_decodes,
+            num_decode_tokens,
         )
         hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
         hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
@@ -470,24 +469,24 @@ def split_batch_to_prefill_and_decode(
     state_indices_tensor: torch.Tensor,
     num_prefill_tokens: int,
     num_prefills: int,
-    num_padded_decodes: int,
+    num_decode_tokens: int,
 ) -> PrefillDecodeSplit:
-    num_actual_tokens = num_prefill_tokens + num_padded_decodes
+    num_actual_tokens = num_prefill_tokens + num_decode_tokens
     # In v1, decode tokens come first, then prefill tokens.
     hidden_states_BC_d, hidden_states_BC_p = torch.split(
         hidden_states_BC[..., :num_actual_tokens],
-        [num_padded_decodes, num_prefill_tokens],
+        [num_decode_tokens, num_prefill_tokens],
         dim=-1,
     )
     gate_d, gate_p = torch.split(
-        gate[..., :num_actual_tokens], [num_padded_decodes, num_prefill_tokens], dim=-1
+        gate[..., :num_actual_tokens], [num_decode_tokens, num_prefill_tokens], dim=-1
     )
-    # num_padded_decodes accounts for CUDA graph padding when applicable
+    # num_decode_tokens accounts for CUDA graph padding when applicable
     state_indices_tensor_d, state_indices_tensor_p = torch.split(
-        state_indices_tensor[: num_padded_decodes + num_prefills],
-        [num_padded_decodes, num_prefills],
+        state_indices_tensor[: num_decode_tokens + num_prefills],
+        [num_decode_tokens, num_prefills],
         dim=0,
     )
diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index 69b5a6fb48564..e921f8c3de073 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -254,17 +254,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             )
         else:
             has_initial_state = None
-        num_actual_tokens = (
-            num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
-        )
 
-        # prepare tensors for cudagraph
-        #
-        # With speculative decoding, the xgrammar backend may rollback tokens
-        # and causing some sequences has less draft tokens than self.num_spec.
-        #
-        # In above cases, the max possible batch size for n tokens, can be
-        # min(n, cudagraph_max_bs).
+        # Prepare tensors for cudagraph
+        # Note: m.num_actual_tokens is already padded by the model runner for CUDAGraph
+        batch_size = m.num_actual_tokens
+
         if (
             self.use_full_cuda_graph
             and num_prefills == 0
@@ -272,9 +266,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             and num_spec_decodes <= self.decode_cudagraph_max_bs
             and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
         ):
-            num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
-            batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
-
             self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                 spec_state_indices_tensor, non_blocking=True
             )
@@ -319,9 +310,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             and num_spec_decodes == 0
             and num_decodes <= self.decode_cudagraph_max_bs
        ):
-            num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
-            batch_size = num_actual_tokens
-
             self.non_spec_state_indices_tensor[:num_decodes].copy_(
                 non_spec_state_indices_tensor, non_blocking=True
             )
@@ -344,7 +332,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             num_decode_tokens=num_decode_tokens,
             num_spec_decodes=num_spec_decodes,
             num_spec_decode_tokens=num_spec_decode_tokens,
-            num_actual_tokens=num_actual_tokens,
+            num_actual_tokens=m.num_actual_tokens,
             has_initial_state=has_initial_state,
             spec_query_start_loc=spec_query_start_loc,
             non_spec_query_start_loc=non_spec_query_start_loc,
diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py
index 8e949e53330c1..fcda6134016ba 100644
--- a/vllm/v1/attention/backends/mamba1_attn.py
+++ b/vllm/v1/attention/backends/mamba1_attn.py
@@ -31,7 +31,6 @@ class Mamba1AttentionMetadata:
     num_prefill_tokens: int
     num_decodes: int
     num_decode_tokens: int
-    num_padded_decodes: int
 
     block_idx_last_scheduled_token: torch.Tensor  # shape: [batch,]
     block_idx_first_scheduled_token_p: torch.Tensor  # shape: [batch,]
@@ -68,7 +67,6 @@ class Mamba1AttentionMetadataBuilder(
         has_initial_states_p = None
         query_start_loc_p = None
-        padded_decodes = num_decodes
         num_computed_tokens, num_computed_tokens_p = None, None
         block_idx_first_scheduled_token = None
         block_idx_first_scheduled_token_p = None
@@ -125,11 +123,10 @@ class Mamba1AttentionMetadataBuilder(
             and num_decodes <= self.decode_cudagraph_max_bs
             and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         ):
-            padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
             self.state_indices_tensor[:num_decodes].copy_(
                 state_indices_tensor, non_blocking=True
             )
-            state_indices_tensor = self.state_indices_tensor[:padded_decodes]
+            state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
             state_indices_tensor[num_decodes:] = PAD_SLOT_ID
 
             if self.vllm_config.cache_config.enable_prefix_caching:
@@ -137,17 +134,15 @@ class Mamba1AttentionMetadataBuilder(
                     block_idx_last_scheduled_token, non_blocking=True
                 )
                 block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
-                    :padded_decodes
+                    :num_decode_tokens
                 ]
-                block_idx_last_scheduled_token[num_decodes:] = 0
 
                 self.block_idx_last_computed_token[:num_decodes].copy_(
                     block_idx_last_computed_token, non_blocking=True
                 )
                 block_idx_last_computed_token = self.block_idx_last_computed_token[
-                    :padded_decodes
+                    :num_decode_tokens
                 ]
-                block_idx_last_computed_token[num_decodes:] = 0
 
         return Mamba1AttentionMetadata(
             query_start_loc_p=query_start_loc_p,
@@ -157,7 +152,6 @@
             num_prefill_tokens=num_prefill_tokens,
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,
-            num_padded_decodes=padded_decodes,
             block_idx_last_scheduled_token=block_idx_last_scheduled_token,
             block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
             block_idx_last_computed_token=block_idx_last_computed_token,
diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py
index 888734e5d2b6b..bf1d8f09ab0ac 100644
--- a/vllm/v1/attention/backends/mamba2_attn.py
+++ b/vllm/v1/attention/backends/mamba2_attn.py
@@ -10,7 +10,6 @@ from vllm.config import VllmConfig
 from vllm.utils.math_utils import cdiv
 from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import (
-    PAD_SLOT_ID,
     CommonAttentionMetadata,
     compute_causal_conv1d_metadata,
     split_decodes_and_prefills,
@@ -304,30 +303,25 @@ class Mamba2AttentionMetadataBuilder(
                 num_decodes <= self.decode_cudagraph_max_bs
                 and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
             ):
-                # Pad state tensor for CUDA graph
-                num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
                 self.state_indices_tensor[:num_decodes].copy_(
                     state_indices_tensor, non_blocking=True
                 )
-                state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
-                state_indices_tensor[num_decodes:] = PAD_SLOT_ID
+                state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
 
                 if self.vllm_config.cache_config.enable_prefix_caching:
                     self.block_idx_last_scheduled_token[:num_decodes].copy_(
                         block_idx_last_scheduled_token, non_blocking=True
                     )
                     block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
-                        :num_input_tokens
+                        :num_decode_tokens
                     ]
-                    block_idx_last_scheduled_token[num_decodes:] = 0
 
                     self.block_idx_last_computed_token[:num_decodes].copy_(
                         block_idx_last_computed_token, non_blocking=True
                     )
                     block_idx_last_computed_token = self.block_idx_last_computed_token[
-                        :num_input_tokens
+                        :num_decode_tokens
                     ]
-                    block_idx_last_computed_token[num_decodes:] = 0
 
         attn_metadata = Mamba2AttentionMetadata(
             num_prefills=num_prefills,
diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py
index de0cb73db0917..c8fe0faf71088 100644
--- a/vllm/v1/attention/backends/short_conv_attn.py
+++ b/vllm/v1/attention/backends/short_conv_attn.py
@@ -83,11 +83,10 @@ class ShortConvAttentionMetadataBuilder(
             and num_decodes <= self.decode_cudagraph_max_bs
             and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         ):
-            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
             self.state_indices_tensor[:num_decodes].copy_(
                 state_indices_tensor, non_blocking=True
             )
-            state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
+            state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
             state_indices_tensor[num_decodes:] = PAD_SLOT_ID
 
         attn_metadata = ShortConvAttentionMetadata(
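
The mamba_mixer change above relies on the v1 layout where decode tokens come first and num_decode_tokens already includes any CUDA-graph padding applied by the model runner, so split_batch_to_prefill_and_decode no longer needs a separate num_padded_decodes argument. Below is a minimal sketch of that decode-first split; names and shapes are illustrative, not the actual vLLM signatures:

    import torch

    def split_decode_first(
        hidden_states_BC: torch.Tensor,  # [channels, total_tokens], decode tokens first
        num_decode_tokens: int,          # decode token count, already CUDA-graph padded
        num_prefill_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Only the first num_decode_tokens + num_prefill_tokens columns are used;
        # any trailing buffer slack beyond that is sliced away.
        num_actual_tokens = num_prefill_tokens + num_decode_tokens
        decode_part, prefill_part = torch.split(
            hidden_states_BC[..., :num_actual_tokens],
            [num_decode_tokens, num_prefill_tokens],
            dim=-1,
        )
        return decode_part, prefill_part

    # Example: 4 (padded) decode tokens followed by 6 prefill tokens.
    x = torch.arange(10, dtype=torch.float32).repeat(3, 1)  # [3 channels, 10 tokens]
    d, p = split_decode_first(x, num_decode_tokens=4, num_prefill_tokens=6)
    assert d.shape == (3, 4) and p.shape == (3, 6)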
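For the metadata builders, the padded decode batch size now arrives as num_decode_tokens instead of being recomputed per backend via pad_for_cudagraph, while the buffer handling itself stays the same: real indices fill the first num_decodes entries of a persistent, capture-sized tensor, and (as retained in mamba1_attn.py and short_conv_attn.py) the padded tail points at a dummy slot. A simplified sketch of that pattern, using an assumed PAD_SLOT_ID sentinel and made-up sizes:

    import torch

    PAD_SLOT_ID = -1              # assumed sentinel for padded lanes
    DECODE_CUDAGRAPH_MAX_BS = 8   # assumed max captured decode batch size

    # Persistent buffer sized for the largest captured batch, reused across steps.
    persistent_state_indices = torch.empty(DECODE_CUDAGRAPH_MAX_BS, dtype=torch.int64)

    def build_state_indices(state_indices: torch.Tensor, num_decode_tokens: int) -> torch.Tensor:
        num_decodes = state_indices.numel()
        # Copy the real slot indices for the scheduled decode requests.
        persistent_state_indices[:num_decodes].copy_(state_indices)
        # Hand out a view sized to the padded CUDA-graph batch.
        padded = persistent_state_indices[:num_decode_tokens]
        # Padded lanes write their throwaway state to the dummy slot.
        padded[num_decodes:] = PAD_SLOT_ID
        return padded

    # Example: 3 real decode requests, captured graph batch padded to 4.
    out = build_state_indices(torch.tensor([5, 9, 2]), num_decode_tokens=4)
    assert out.tolist() == [5, 9, 2, -1]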