mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 15:02:17 +08:00
Remove build_for_cudagraph_capture method and use build directly
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
This commit is contained in:
parent
b8d7f55dbd
commit
fb0089c536
@ -459,6 +459,13 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
fast_build: bool = False) -> FlashInferMetadata:
|
fast_build: bool = False) -> FlashInferMetadata:
|
||||||
|
# For full cudagraph capture, ensure decode-only mode
|
||||||
|
if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
|
||||||
|
# This is likely a cudagraph capture scenario
|
||||||
|
assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
|
||||||
|
"FlashInfer only supports decode-only full CUDAGraph capture. " \
|
||||||
|
"Make sure all cudagraph capture sizes <= max_num_seq."
|
||||||
|
|
||||||
num_reqs = common_attn_metadata.num_reqs
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
|
||||||
@ -577,22 +584,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
|
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
def build_for_cudagraph_capture(
|
|
||||||
self, common_attn_metadata: CommonAttentionMetadata):
|
|
||||||
"""
|
|
||||||
This method builds the metadata for full cudagraph capture.
|
|
||||||
Currently, only decode is supported for full cudagraphs with FlashInfer.
|
|
||||||
"""
|
|
||||||
m = common_attn_metadata
|
|
||||||
|
|
||||||
assert m.num_reqs == m.num_actual_tokens, \
|
|
||||||
"FlashInfer only supports decode-only full CUDAGraph capture. " \
|
|
||||||
"Make sure all cudagraph capture sizes <= max_num_seq."
|
|
||||||
|
|
||||||
m.max_query_len = 1 # decode-only
|
|
||||||
|
|
||||||
return self.build(0, m)
|
|
||||||
|
|
||||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||||
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
|
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
|
||||||
# TODO: The cascade wrapper currently does not support setting
|
# TODO: The cascade wrapper currently does not support setting
|
||||||
|
|||||||
@ -43,6 +43,13 @@ class Mamba1AttentionMetadataBuilder(
|
|||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
fast_build: bool = False,
|
fast_build: bool = False,
|
||||||
) -> Mamba1AttentionMetadata:
|
) -> Mamba1AttentionMetadata:
|
||||||
|
# For full cudagraph capture, ensure decode-only mode
|
||||||
|
if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
|
||||||
|
# This is likely a cudagraph capture scenario
|
||||||
|
assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
|
||||||
|
"Mamba only supports decode-only full CUDAGraph capture. " \
|
||||||
|
"Make sure all cudagraph capture sizes <= max_num_seq."
|
||||||
|
|
||||||
query_start_loc = common_attn_metadata.query_start_loc
|
query_start_loc = common_attn_metadata.query_start_loc
|
||||||
|
|
||||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||||
|
|||||||
@ -101,6 +101,13 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
fast_build: bool = False) -> Mamba2AttentionMetadata:
|
fast_build: bool = False) -> Mamba2AttentionMetadata:
|
||||||
|
# For full cudagraph capture, ensure decode-only mode
|
||||||
|
if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
|
||||||
|
# This is likely a cudagraph capture scenario
|
||||||
|
assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
|
||||||
|
"Mamba only supports decode-only full CUDAGraph capture. " \
|
||||||
|
"Make sure all cudagraph capture sizes <= max_num_seq."
|
||||||
|
|
||||||
num_reqs = common_attn_metadata.num_reqs
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
query_start_loc = common_attn_metadata.query_start_loc
|
query_start_loc = common_attn_metadata.query_start_loc
|
||||||
seq_lens = common_attn_metadata.seq_lens
|
seq_lens = common_attn_metadata.seq_lens
|
||||||
|
|||||||
@ -38,18 +38,3 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def build_for_cudagraph_capture(
|
|
||||||
self, common_attn_metadata: CommonAttentionMetadata) -> M:
|
|
||||||
"""
|
|
||||||
This method builds the metadata for full cudagraph capture.
|
|
||||||
Currently, only decode is supported for full cudagraphs with Mamba.
|
|
||||||
"""
|
|
||||||
m = common_attn_metadata
|
|
||||||
|
|
||||||
assert m.num_reqs == m.num_actual_tokens, \
|
|
||||||
"Mamba only supports decode-only full CUDAGraph capture. " \
|
|
||||||
"Make sure all cudagraph capture sizes <= max_num_seq."
|
|
||||||
|
|
||||||
m.max_query_len = 1 # decode-only
|
|
||||||
|
|
||||||
return self.build(0, m)
|
|
||||||
@ -561,25 +561,17 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
)
|
)
|
||||||
|
|
||||||
def build_for_cudagraph_capture(
|
|
||||||
self, common_attn_metadata: CommonAttentionMetadata) -> M:
|
|
||||||
"""
|
|
||||||
This method builds the metadata for full cudagraph capture.
|
|
||||||
Currently, only decode is supported for full cudagraphs with MLA.
|
|
||||||
"""
|
|
||||||
m = common_attn_metadata
|
|
||||||
assert m.num_reqs == m.num_actual_tokens, \
|
|
||||||
"MLA only supports decode-only full CUDAGraph capture. " \
|
|
||||||
"Make sure all cudagraph capture sizes <= max_num_seq."
|
|
||||||
|
|
||||||
assert m.max_query_len == 1 # decode-only
|
|
||||||
|
|
||||||
return self.build(0, m)
|
|
||||||
|
|
||||||
def build(self,
|
def build(self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
fast_build: bool = False) -> M:
|
fast_build: bool = False) -> M:
|
||||||
|
# For full cudagraph capture, ensure decode-only mode
|
||||||
|
if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
|
||||||
|
# This is likely a cudagraph capture scenario
|
||||||
|
assert common_attn_metadata.num_reqs == common_attn_metadata.num_actual_tokens, \
|
||||||
|
"MLA only supports decode-only full CUDAGraph capture. " \
|
||||||
|
"Make sure all cudagraph capture sizes <= max_num_seq."
|
||||||
|
|
||||||
num_reqs = common_attn_metadata.num_reqs
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
num_tokens = common_attn_metadata.num_actual_tokens
|
num_tokens = common_attn_metadata.num_actual_tokens
|
||||||
max_query_len = common_attn_metadata.max_query_len
|
max_query_len = common_attn_metadata.max_query_len
|
||||||
|
|||||||
@ -254,19 +254,19 @@ class AiterFlashAttentionMetadataBuilder(
|
|||||||
self.aot_sliding_window: Optional[tuple[int, int]] = None
|
self.aot_sliding_window: Optional[tuple[int, int]] = None
|
||||||
self.total_tokens: int = 0
|
self.total_tokens: int = 0
|
||||||
|
|
||||||
def build_for_cudagraph_capture(
|
|
||||||
self, common_attn_metadata: CommonAttentionMetadata):
|
|
||||||
self.total_tokens = self.model_config.max_model_len \
|
|
||||||
* self.vllm_config.scheduler_config.max_num_partial_prefills
|
|
||||||
res = self.build(common_prefix_len=0,
|
|
||||||
common_attn_metadata=common_attn_metadata)
|
|
||||||
self.total_tokens = 0
|
|
||||||
return res
|
|
||||||
|
|
||||||
def build(self,
|
def build(self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
fast_build: bool = False) -> 'AiterFlashAttentionMetadata':
|
fast_build: bool = False) -> 'AiterFlashAttentionMetadata':
|
||||||
|
# Handle total_tokens for cudagraph capture scenarios
|
||||||
|
is_cudagraph_capture = (common_prefix_len == 0 and
|
||||||
|
common_attn_metadata.max_query_len == 1)
|
||||||
|
if is_cudagraph_capture:
|
||||||
|
original_total_tokens = self.total_tokens
|
||||||
|
self.total_tokens = self.model_config.max_model_len \
|
||||||
|
* self.vllm_config.scheduler_config.max_num_partial_prefills
|
||||||
|
else:
|
||||||
|
original_total_tokens = None
|
||||||
|
|
||||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||||
max_query_len = common_attn_metadata.max_query_len
|
max_query_len = common_attn_metadata.max_query_len
|
||||||
@ -310,6 +310,11 @@ class AiterFlashAttentionMetadataBuilder(
|
|||||||
common_prefix_len=common_prefix_len,
|
common_prefix_len=common_prefix_len,
|
||||||
total_tokens=self.total_tokens,
|
total_tokens=self.total_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Restore total_tokens value if this was a cudagraph capture
|
||||||
|
if is_cudagraph_capture:
|
||||||
|
self.total_tokens = original_total_tokens
|
||||||
|
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||||
|
|||||||
@ -73,16 +73,6 @@ class TritonAttentionMetadataBuilder(
|
|||||||
vllm_config.parallel_config)
|
vllm_config.parallel_config)
|
||||||
self.headdim = model_config.get_head_size()
|
self.headdim = model_config.get_head_size()
|
||||||
|
|
||||||
def build_for_cudagraph_capture(
|
|
||||||
self, common_attn_metadata: CommonAttentionMetadata
|
|
||||||
) -> TritonAttentionMetadata:
|
|
||||||
attn_metadata = self.build(0, common_attn_metadata)
|
|
||||||
# When doing full graph capture, setting seq_lens to
|
|
||||||
# max_model_len will cause graph capture to be extremely
|
|
||||||
# slow, so here we set it to 1.
|
|
||||||
attn_metadata.seq_lens.fill_(1)
|
|
||||||
return attn_metadata
|
|
||||||
|
|
||||||
def build(self,
|
def build(self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
@ -129,6 +119,14 @@ class TritonAttentionMetadataBuilder(
|
|||||||
suffix_kv_lens=suffix_kv_lens,
|
suffix_kv_lens=suffix_kv_lens,
|
||||||
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
prefix_scheduler_metadata=prefix_scheduler_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Handle cudagraph capture optimizations
|
||||||
|
if common_prefix_len == 0 and common_attn_metadata.max_query_len == 1:
|
||||||
|
# When doing full graph capture, setting seq_lens to
|
||||||
|
# max_model_len will cause graph capture to be extremely
|
||||||
|
# slow, so here we set it to 1.
|
||||||
|
attn_metadata.seq_lens.fill_(1)
|
||||||
|
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -205,16 +205,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def build_for_cudagraph_capture(
|
|
||||||
self, common_attn_metadata: CommonAttentionMetadata) -> M:
|
|
||||||
"""
|
|
||||||
Build attention metadata for CUDA graph capture. Uses build by default.
|
|
||||||
Subclasses that override this method should call self.build or
|
|
||||||
super().build_for_cudagraph_capture.
|
|
||||||
"""
|
|
||||||
return self.build(common_prefix_len=0,
|
|
||||||
common_attn_metadata=common_attn_metadata)
|
|
||||||
|
|
||||||
def build_for_drafting(
|
def build_for_drafting(
|
||||||
self,
|
self,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
|||||||
@ -2306,7 +2306,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
for attn_group in self.attn_groups[kv_cache_group_id]:
|
for attn_group in self.attn_groups[kv_cache_group_id]:
|
||||||
attn_metadata_i = attn_group.metadata_builder\
|
attn_metadata_i = attn_group.metadata_builder\
|
||||||
.build_for_cudagraph_capture(common_attn_metadata)
|
.build(common_prefix_len=0,
|
||||||
|
common_attn_metadata=common_attn_metadata)
|
||||||
for layer_name in kv_cache_group_spec.layer_names:
|
for layer_name in kv_cache_group_spec.layer_names:
|
||||||
attn_metadata[layer_name] = attn_metadata_i
|
attn_metadata[layer_name] = attn_metadata_i
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user