mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-27 10:21:48 +08:00
[BugFix] Fix DCP Assert (AssertionError: DCP not support reorder_batch_threshold > 1 now.) (#28100)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
0ff05e3770
commit
d43ad5a757
@ -545,6 +545,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
metadata_cls: type[M] | None = None,
|
||||
supports_dcp_with_varlen: bool = False,
|
||||
):
|
||||
self.metadata_cls = (
|
||||
metadata_cls if metadata_cls is not None else MLACommonMetadata
|
||||
@ -638,7 +639,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
|
||||
supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
|
||||
self._init_reorder_batch_threshold(
|
||||
self.reorder_batch_threshold, supports_spec_decode
|
||||
self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen
|
||||
)
|
||||
|
||||
# Validate consistency between query_len_support and reorder_batch_threshold
|
||||
|
||||
@ -81,7 +81,12 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_spec, layer_names, vllm_config, device, FlashAttnMLAMetadata
|
||||
kv_cache_spec,
|
||||
layer_names,
|
||||
vllm_config,
|
||||
device,
|
||||
FlashAttnMLAMetadata,
|
||||
supports_dcp_with_varlen=True,
|
||||
)
|
||||
self.max_num_splits = 0 # No upper bound on the number of splits.
|
||||
self.fa_aot_schedule = get_flash_attn_version() == 3
|
||||
|
||||
@ -264,7 +264,10 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
self.device = device
|
||||
|
||||
def _init_reorder_batch_threshold(
|
||||
self, reorder_batch_threshold: int = 1, supports_spec_as_decode: bool = False
|
||||
self,
|
||||
reorder_batch_threshold: int = 1,
|
||||
supports_spec_as_decode: bool = False,
|
||||
supports_dcp_with_varlen: bool = False,
|
||||
) -> None:
|
||||
self.reorder_batch_threshold = reorder_batch_threshold
|
||||
if self.reorder_batch_threshold is not None and supports_spec_as_decode:
|
||||
@ -281,6 +284,12 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
1 + speculative_config.num_speculative_tokens,
|
||||
)
|
||||
|
||||
if (
|
||||
self.vllm_config.parallel_config.decode_context_parallel_size > 1
|
||||
and not supports_dcp_with_varlen
|
||||
):
|
||||
self.reorder_batch_threshold = 1
|
||||
|
||||
@abstractmethod
|
||||
def build(
|
||||
self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user