[BugFix] Fix DCP Assert (AssertionError: DCP not support reorder_batch_threshold > 1 now.) (#28100)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Authored by Lucas Wilkinson on 2025-11-05 01:54:43 -05:00; committed by GitHub
parent 0ff05e3770
commit d43ad5a757
3 changed files with 18 additions and 3 deletions


@@ -545,6 +545,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         vllm_config: VllmConfig,
         device: torch.device,
         metadata_cls: type[M] | None = None,
+        supports_dcp_with_varlen: bool = False,
     ):
         self.metadata_cls = (
             metadata_cls if metadata_cls is not None else MLACommonMetadata
@@ -638,7 +639,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
         self._init_reorder_batch_threshold(
-            self.reorder_batch_threshold, supports_spec_decode
+            self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen
         )
         # Validate consistency between query_len_support and reorder_batch_threshold

@@ -81,7 +81,12 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         device: torch.device,
     ):
         super().__init__(
-            kv_cache_spec, layer_names, vllm_config, device, FlashAttnMLAMetadata
+            kv_cache_spec,
+            layer_names,
+            vllm_config,
+            device,
+            FlashAttnMLAMetadata,
+            supports_dcp_with_varlen=True,
         )
         self.max_num_splits = 0  # No upper bound on the number of splits.
         self.fa_aot_schedule = get_flash_attn_version() == 3

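For orientation, a minimal sketch of how the new flag flows through the builders above. The class names are illustrative stand-ins for MLACommonMetadataBuilder and FlashAttnMLAMetadataBuilder, not the actual vLLM code, and everything except the flag plumbing is elided.

# Illustrative stand-ins only; the real builders also take kv_cache_spec,
# layer_names, vllm_config, device, etc., which are omitted here.
class CommonMLABuilderSketch:
    def __init__(self, supports_dcp_with_varlen: bool = False) -> None:
        # The default stays False, so MLA backends that do not override it
        # get the conservative DCP fallback added in the base class below.
        self.supports_dcp_with_varlen = supports_dcp_with_varlen


class FlashAttnMLABuilderSketch(CommonMLABuilderSketch):
    def __init__(self) -> None:
        # FlashAttn MLA can handle DCP together with variable-length
        # queries, so it opts in explicitly.
        super().__init__(supports_dcp_with_varlen=True)


assert CommonMLABuilderSketch().supports_dcp_with_varlen is False
assert FlashAttnMLABuilderSketch().supports_dcp_with_varlen is True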

@@ -264,7 +264,10 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
         self.device = device

     def _init_reorder_batch_threshold(
-        self, reorder_batch_threshold: int = 1, supports_spec_as_decode: bool = False
+        self,
+        reorder_batch_threshold: int = 1,
+        supports_spec_as_decode: bool = False,
+        supports_dcp_with_varlen: bool = False,
     ) -> None:
         self.reorder_batch_threshold = reorder_batch_threshold
         if self.reorder_batch_threshold is not None and supports_spec_as_decode:
@@ -281,6 +284,12 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
                 1 + speculative_config.num_speculative_tokens,
             )

+        if (
+            self.vllm_config.parallel_config.decode_context_parallel_size > 1
+            and not supports_dcp_with_varlen
+        ):
+            self.reorder_batch_threshold = 1
+
     @abstractmethod
     def build(
         self,
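And a small, self-contained sketch of the clamping rule the last hunk adds. ParallelConfig and the helper function below are simplified stand-ins, but the condition mirrors the diff: with decode context parallelism (DCP) enabled and a backend that has not opted in, the reorder-batch threshold falls back to 1, which is what avoids the "DCP not support reorder_batch_threshold > 1 now." assertion named in the title.

from dataclasses import dataclass


@dataclass
class ParallelConfig:
    # Stand-in for vllm_config.parallel_config in the hunk above.
    decode_context_parallel_size: int = 1


def clamp_reorder_batch_threshold(
    parallel_config: ParallelConfig,
    reorder_batch_threshold: int,
    supports_dcp_with_varlen: bool = False,
) -> int:
    # Mirrors the new check in _init_reorder_batch_threshold: under DCP,
    # only backends that support variable-length queries may keep a
    # threshold greater than 1.
    if (
        parallel_config.decode_context_parallel_size > 1
        and not supports_dcp_with_varlen
    ):
        return 1
    return reorder_batch_threshold


# DCP enabled, backend without varlen support: threshold is clamped to 1.
assert clamp_reorder_batch_threshold(ParallelConfig(2), 4) == 1
# DCP enabled, backend that opted in (e.g. FlashAttn MLA): threshold kept.
assert clamp_reorder_batch_threshold(ParallelConfig(2), 4, supports_dcp_with_varlen=True) == 4
# No DCP: the threshold is unchanged either way.
assert clamp_reorder_batch_threshold(ParallelConfig(1), 4) == 4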