Remove build_for_cudagraph_capture method and use build directly

Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot] 2025-08-25 17:07:47 +00:00
parent b8d7f55dbd
commit fb0089c536
9 changed files with 52 additions and 76 deletions

View File

@ -459,6 +459,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> FlashInferMetadata: fast_build: bool = False) -> FlashInferMetadata:
# For full cudagraph capture, ensure decode-only mode
if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
# This is likely a cudagraph capture scenario
assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
"FlashInfer only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
num_reqs = common_attn_metadata.num_reqs num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
@ -577,22 +584,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
return attn_metadata return attn_metadata
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with FlashInfer.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \
"FlashInfer only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
m.max_query_len = 1 # decode-only
return self.build(0, m)
def use_cascade_attention(self, *args, **kwargs) -> bool: def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype: if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting # TODO: The cascade wrapper currently does not support setting

View File

@ -43,6 +43,13 @@ class Mamba1AttentionMetadataBuilder(
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False, fast_build: bool = False,
) -> Mamba1AttentionMetadata: ) -> Mamba1AttentionMetadata:
# For full cudagraph capture, ensure decode-only mode
if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
# This is likely a cudagraph capture scenario
assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
"Mamba only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
query_start_loc = common_attn_metadata.query_start_loc query_start_loc = common_attn_metadata.query_start_loc
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

View File

@ -101,6 +101,13 @@ class Mamba2AttentionMetadataBuilder(
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> Mamba2AttentionMetadata: fast_build: bool = False) -> Mamba2AttentionMetadata:
# For full cudagraph capture, ensure decode-only mode
if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
# This is likely a cudagraph capture scenario
assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
"Mamba only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
num_reqs = common_attn_metadata.num_reqs num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens seq_lens = common_attn_metadata.seq_lens

View File

@ -38,18 +38,3 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
device=device, device=device,
) )
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M:
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with Mamba.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \
"Mamba only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
m.max_query_len = 1 # decode-only
return self.build(0, m)

View File

@ -561,25 +561,17 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens=seq_lens, seq_lens=seq_lens,
) )
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M:
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with MLA.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
assert m.max_query_len == 1 # decode-only
return self.build(0, m)
def build(self, def build(self,
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> M: fast_build: bool = False) -> M:
# For full cudagraph capture, ensure decode-only mode
if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
# This is likely a cudagraph capture scenario
assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
num_reqs = common_attn_metadata.num_reqs num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens num_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len max_query_len = common_attn_metadata.max_query_len

View File

@ -254,19 +254,19 @@ class AiterFlashAttentionMetadataBuilder(
self.aot_sliding_window: Optional[tuple[int, int]] = None self.aot_sliding_window: Optional[tuple[int, int]] = None
self.total_tokens: int = 0 self.total_tokens: int = 0
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata):
self.total_tokens = self.model_config.max_model_len \
* self.vllm_config.scheduler_config.max_num_partial_prefills
res = self.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
self.total_tokens = 0
return res
def build(self, def build(self,
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> 'AiterFlashAttentionMetadata': fast_build: bool = False) -> 'AiterFlashAttentionMetadata':
# Handle total_tokens for cudagraph capture scenarios
is_cudagraph_capture = (common_prefix_len == 0 and
common_attn_metadata.max_query_len == 1)
if is_cudagraph_capture:
original_total_tokens = self.total_tokens
self.total_tokens = self.model_config.max_model_len \
* self.vllm_config.scheduler_config.max_num_partial_prefills
else:
original_total_tokens = None
num_actual_tokens = common_attn_metadata.num_actual_tokens num_actual_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len max_query_len = common_attn_metadata.max_query_len
@ -310,6 +310,11 @@ class AiterFlashAttentionMetadataBuilder(
common_prefix_len=common_prefix_len, common_prefix_len=common_prefix_len,
total_tokens=self.total_tokens, total_tokens=self.total_tokens,
) )
# Restore total_tokens value if this was a cudagraph capture
if is_cudagraph_capture:
self.total_tokens = original_total_tokens
return attn_metadata return attn_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool: def use_cascade_attention(self, *args, **kwargs) -> bool:

View File

@ -73,16 +73,6 @@ class TritonAttentionMetadataBuilder(
vllm_config.parallel_config) vllm_config.parallel_config)
self.headdim = model_config.get_head_size() self.headdim = model_config.get_head_size()
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> TritonAttentionMetadata:
attn_metadata = self.build(0, common_attn_metadata)
# When doing full graph capture, setting seq_lens to
# max_model_len will cause graph capture to be extremely
# slow, so here we set it to 1.
attn_metadata.seq_lens.fill_(1)
return attn_metadata
def build(self, def build(self,
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
@ -129,6 +119,14 @@ class TritonAttentionMetadataBuilder(
suffix_kv_lens=suffix_kv_lens, suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata, prefix_scheduler_metadata=prefix_scheduler_metadata,
) )
# Handle cudagraph capture optimizations
if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
# When doing full graph capture, setting seq_lens to
# max_model_len will cause graph capture to be extremely
# slow, so here we set it to 1.
attn_metadata.seq_lens.fill_(1)
return attn_metadata return attn_metadata

View File

@ -205,16 +205,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
""" """
raise NotImplementedError raise NotImplementedError
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata) -> M:
"""
Build attention metadata for CUDA graph capture. Uses build by default.
Subclasses that override this method should call self.build or
super().build_for_cudagraph_capture.
"""
return self.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
def build_for_drafting( def build_for_drafting(
self, self,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,

View File

@ -2306,7 +2306,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for attn_group in self.attn_groups[kv_cache_group_id]: for attn_group in self.attn_groups[kv_cache_group_id]:
attn_metadata_i = attn_group.metadata_builder\ attn_metadata_i = attn_group.metadata_builder\
.build_for_cudagraph_capture(common_attn_metadata) .build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names: for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i attn_metadata[layer_name] = attn_metadata_i