[Attention][CUDAGraph] Remove CG padding from attention backends (#29352)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Matthew Bonanni 2025-12-02 13:48:08 -05:00 committed by GitHub
parent 2d613de9ae
commit 1d93f11675
GPG Key ID: B5690EEEBB952194
5 changed files with 20 additions and 46 deletions
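In short, the CUDA graph padding responsibility moves out of the attention/Mamba metadata builders and into the model runner: instead of each builder calling vllm_config.pad_for_cudagraph itself, it receives counts that are already padded (m.num_actual_tokens, num_decode_tokens) and slices its persistent buffers with them directly. A minimal before/after sketch of that pattern (the surrounding scope and scaffolding are illustrative, not the exact code):

# Before: each backend padded the decode batch for CUDA graph capture itself.
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
state_indices_tensor = self.state_indices_tensor[:num_input_tokens]

# After: the model runner has already padded the counts handed to the builder,
# so the backend slices with num_decode_tokens (or m.num_actual_tokens) as-is.
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]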


@@ -252,7 +252,6 @@ class MambaMixer(MambaBase, CustomOp):
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
- num_padded_decodes = attn_metadata.num_padded_decodes
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -281,7 +280,7 @@ class MambaMixer(MambaBase, CustomOp):
state_indices_tensor,
num_prefill_tokens,
num_prefills,
- num_padded_decodes,
+ num_decode_tokens,
)
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
@@ -470,24 +469,24 @@ def split_batch_to_prefill_and_decode(
state_indices_tensor: torch.Tensor,
num_prefill_tokens: int,
num_prefills: int,
- num_padded_decodes: int,
+ num_decode_tokens: int,
) -> PrefillDecodeSplit:
- num_actual_tokens = num_prefill_tokens + num_padded_decodes
+ num_actual_tokens = num_prefill_tokens + num_decode_tokens
# In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d, hidden_states_BC_p = torch.split(
hidden_states_BC[..., :num_actual_tokens],
- [num_padded_decodes, num_prefill_tokens],
+ [num_decode_tokens, num_prefill_tokens],
dim=-1,
)
gate_d, gate_p = torch.split(
- gate[..., :num_actual_tokens], [num_padded_decodes, num_prefill_tokens], dim=-1
+ gate[..., :num_actual_tokens], [num_decode_tokens, num_prefill_tokens], dim=-1
)
- # num_padded_decodes accounts for CUDA graph padding when applicable
+ # num_decode_tokens accounts for CUDA graph padding when applicable
state_indices_tensor_d, state_indices_tensor_p = torch.split(
- state_indices_tensor[: num_padded_decodes + num_prefills],
- [num_padded_decodes, num_prefills],
+ state_indices_tensor[: num_decode_tokens + num_prefills],
+ [num_decode_tokens, num_prefills],
dim=0,
)
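For reference, the decode-first split performed by split_batch_to_prefill_and_decode can be reproduced in isolation. A minimal runnable sketch (the tensor sizes are made up; the torch.split semantics are the real ones):

import torch

num_decode_tokens, num_prefill_tokens, num_prefills = 3, 5, 2
# [dim, tokens] layout with decode tokens first, then prefill tokens (v1 ordering).
hidden_states_BC = torch.randn(8, num_decode_tokens + num_prefill_tokens)
# One state index per decode token slot, then one per prefill request.
state_indices_tensor = torch.arange(num_decode_tokens + num_prefills)

hidden_states_BC_d, hidden_states_BC_p = torch.split(
    hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1
)
state_indices_tensor_d, state_indices_tensor_p = torch.split(
    state_indices_tensor, [num_decode_tokens, num_prefills], dim=0
)
assert hidden_states_BC_d.shape[-1] == num_decode_tokens
assert state_indices_tensor_p.numel() == num_prefills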


@@ -254,17 +254,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
)
else:
has_initial_state = None
- num_actual_tokens = (
- num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
- )
- # prepare tensors for cudagraph
- #
- # With speculative decoding, the xgrammar backend may rollback tokens
- # and causing some sequences has less draft tokens than self.num_spec.
- #
- # In above cases, the max possible batch size for n tokens, can be
- # min(n, cudagraph_max_bs).
+ # Prepare tensors for cudagraph
+ # Note: m.num_actual_tokens is already padded by the model runner for CUDAGraph
+ batch_size = m.num_actual_tokens
if (
self.use_full_cuda_graph
and num_prefills == 0
@@ -272,9 +266,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
and num_spec_decodes <= self.decode_cudagraph_max_bs
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
):
- num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
- batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
self.spec_state_indices_tensor[:num_spec_decodes].copy_(
spec_state_indices_tensor, non_blocking=True
)
@@ -319,9 +310,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
and num_spec_decodes == 0
and num_decodes <= self.decode_cudagraph_max_bs
):
- num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
- batch_size = num_actual_tokens
self.non_spec_state_indices_tensor[:num_decodes].copy_(
non_spec_state_indices_tensor, non_blocking=True
)
@@ -344,7 +332,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
num_decode_tokens=num_decode_tokens,
num_spec_decodes=num_spec_decodes,
num_spec_decode_tokens=num_spec_decode_tokens,
- num_actual_tokens=num_actual_tokens,
+ num_actual_tokens=m.num_actual_tokens,
has_initial_state=has_initial_state,
spec_query_start_loc=spec_query_start_loc,
non_spec_query_start_loc=non_spec_query_start_loc,
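The pad_for_cudagraph calls deleted in this file rounded the decode batch up to a captured CUDA graph size. A self-contained sketch of that rounding, with an assumed helper name and capture-size list (not vLLM's actual implementation), as a reminder of what the model runner now does once up front:

def pad_for_cudagraph_sketch(num_tokens: int, capture_sizes: list[int]) -> int:
    # Round up to the smallest captured size that fits, so a replayed graph
    # always sees a shape it was captured with.
    for size in sorted(capture_sizes):
        if num_tokens <= size:
            return size
    # Larger than every captured size: no graph applies, keep the eager shape.
    return num_tokens

assert pad_for_cudagraph_sketch(3, [1, 2, 4, 8]) == 4
assert pad_for_cudagraph_sketch(16, [1, 2, 4, 8]) == 16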


@@ -31,7 +31,6 @@ class Mamba1AttentionMetadata:
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
- num_padded_decodes: int
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
@@ -68,7 +67,6 @@ class Mamba1AttentionMetadataBuilder(
has_initial_states_p = None
query_start_loc_p = None
- padded_decodes = num_decodes
num_computed_tokens, num_computed_tokens_p = None, None
block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None
@@ -125,11 +123,10 @@ class Mamba1AttentionMetadataBuilder(
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
- padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
- state_indices_tensor = self.state_indices_tensor[:padded_decodes]
+ state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.enable_prefix_caching:
@@ -137,17 +134,15 @@ class Mamba1AttentionMetadataBuilder(
block_idx_last_scheduled_token, non_blocking=True
)
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
- :padded_decodes
+ :num_decode_tokens
]
block_idx_last_scheduled_token[num_decodes:] = 0
self.block_idx_last_computed_token[:num_decodes].copy_(
block_idx_last_computed_token, non_blocking=True
)
block_idx_last_computed_token = self.block_idx_last_computed_token[
- :padded_decodes
+ :num_decode_tokens
]
block_idx_last_computed_token[num_decodes:] = 0
return Mamba1AttentionMetadata(
query_start_loc_p=query_start_loc_p,
@@ -157,7 +152,6 @@ class Mamba1AttentionMetadataBuilder(
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
- num_padded_decodes=padded_decodes,
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token,


@@ -10,7 +10,6 @@ from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
- PAD_SLOT_ID,
CommonAttentionMetadata,
compute_causal_conv1d_metadata,
split_decodes_and_prefills,
@@ -304,30 +303,25 @@ class Mamba2AttentionMetadataBuilder(
num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
- # Pad state tensor for CUDA graph
- num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
- state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
- state_indices_tensor[num_decodes:] = PAD_SLOT_ID
+ state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
if self.vllm_config.cache_config.enable_prefix_caching:
self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True
)
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
- :num_input_tokens
+ :num_decode_tokens
]
block_idx_last_scheduled_token[num_decodes:] = 0
self.block_idx_last_computed_token[:num_decodes].copy_(
block_idx_last_computed_token, non_blocking=True
)
block_idx_last_computed_token = self.block_idx_last_computed_token[
- :num_input_tokens
+ :num_decode_tokens
]
block_idx_last_computed_token[num_decodes:] = 0
attn_metadata = Mamba2AttentionMetadata(
num_prefills=num_prefills,


@@ -83,11 +83,10 @@ class ShortConvAttentionMetadataBuilder(
and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
- num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True
)
- state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
+ state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
attn_metadata = ShortConvAttentionMetadata(
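The PAD_SLOT_ID pattern retained here (and in the Mamba1 builder above) is what keeps the padded lanes harmless: slots past the real num_decodes point at a discard slot. A small illustrative example (PAD_SLOT_ID's value and the buffer sizes are assumptions for the sketch, not taken from the source):

import torch

PAD_SLOT_ID = -1  # assumed sentinel value for illustration; vLLM imports the real constant

num_decodes, num_decode_tokens = 3, 8  # num_decode_tokens already CUDA-graph padded
persistent_state_indices = torch.full((16,), PAD_SLOT_ID, dtype=torch.int32)
persistent_state_indices[:num_decodes] = torch.tensor([5, 9, 2], dtype=torch.int32)

state_indices_tensor = persistent_state_indices[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID  # padded lanes write to the discard slot
print(state_indices_tensor.tolist())  # [5, 9, 2, -1, -1, -1, -1, -1]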