mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 12:31:19 +08:00
[BugFix] Fix DCP Assert (AssertionError: DCP not support reorder_batch_threshold > 1 now.) (#28100)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
0ff05e3770
commit
d43ad5a757
@ -545,6 +545,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
metadata_cls: type[M] | None = None,
|
metadata_cls: type[M] | None = None,
|
||||||
|
supports_dcp_with_varlen: bool = False,
|
||||||
):
|
):
|
||||||
self.metadata_cls = (
|
self.metadata_cls = (
|
||||||
metadata_cls if metadata_cls is not None else MLACommonMetadata
|
metadata_cls if metadata_cls is not None else MLACommonMetadata
|
||||||
@ -638,7 +639,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
|
|
||||||
supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
|
supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
|
||||||
self._init_reorder_batch_threshold(
|
self._init_reorder_batch_threshold(
|
||||||
self.reorder_batch_threshold, supports_spec_decode
|
self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate consistency between query_len_support and reorder_batch_threshold
|
# Validate consistency between query_len_support and reorder_batch_threshold
|
||||||
|
|||||||
@ -81,7 +81,12 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
kv_cache_spec, layer_names, vllm_config, device, FlashAttnMLAMetadata
|
kv_cache_spec,
|
||||||
|
layer_names,
|
||||||
|
vllm_config,
|
||||||
|
device,
|
||||||
|
FlashAttnMLAMetadata,
|
||||||
|
supports_dcp_with_varlen=True,
|
||||||
)
|
)
|
||||||
self.max_num_splits = 0 # No upper bound on the number of splits.
|
self.max_num_splits = 0 # No upper bound on the number of splits.
|
||||||
self.fa_aot_schedule = get_flash_attn_version() == 3
|
self.fa_aot_schedule = get_flash_attn_version() == 3
|
||||||
|
|||||||
@ -264,7 +264,10 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
|||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def _init_reorder_batch_threshold(
|
def _init_reorder_batch_threshold(
|
||||||
self, reorder_batch_threshold: int = 1, supports_spec_as_decode: bool = False
|
self,
|
||||||
|
reorder_batch_threshold: int = 1,
|
||||||
|
supports_spec_as_decode: bool = False,
|
||||||
|
supports_dcp_with_varlen: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.reorder_batch_threshold = reorder_batch_threshold
|
self.reorder_batch_threshold = reorder_batch_threshold
|
||||||
if self.reorder_batch_threshold is not None and supports_spec_as_decode:
|
if self.reorder_batch_threshold is not None and supports_spec_as_decode:
|
||||||
@ -281,6 +284,12 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
|||||||
1 + speculative_config.num_speculative_tokens,
|
1 + speculative_config.num_speculative_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.vllm_config.parallel_config.decode_context_parallel_size > 1
|
||||||
|
and not supports_dcp_with_varlen
|
||||||
|
):
|
||||||
|
self.reorder_batch_threshold = 1
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user