diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 0ec1573004197..413d20ce04021 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -545,6 +545,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         vllm_config: VllmConfig,
         device: torch.device,
         metadata_cls: type[M] | None = None,
+        supports_dcp_with_varlen: bool = False,
     ):
         self.metadata_cls = (
             metadata_cls if metadata_cls is not None else MLACommonMetadata
@@ -638,7 +639,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
         self._init_reorder_batch_threshold(
-            self.reorder_batch_threshold, supports_spec_decode
+            self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen
         )
 
         # Validate consistency between query_len_support and reorder_batch_threshold
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index 6baf45efccb54..7b084ae969d97 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -81,7 +81,12 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         device: torch.device,
     ):
         super().__init__(
-            kv_cache_spec, layer_names, vllm_config, device, FlashAttnMLAMetadata
+            kv_cache_spec,
+            layer_names,
+            vllm_config,
+            device,
+            FlashAttnMLAMetadata,
+            supports_dcp_with_varlen=True,
         )
         self.max_num_splits = 0  # No upper bound on the number of splits.
         self.fa_aot_schedule = get_flash_attn_version() == 3
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 6c750d3448c41..ed0fae3828453 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -264,7 +264,10 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
         self.device = device
 
     def _init_reorder_batch_threshold(
-        self, reorder_batch_threshold: int = 1, supports_spec_as_decode: bool = False
+        self,
+        reorder_batch_threshold: int = 1,
+        supports_spec_as_decode: bool = False,
+        supports_dcp_with_varlen: bool = False,
     ) -> None:
         self.reorder_batch_threshold = reorder_batch_threshold
         if self.reorder_batch_threshold is not None and supports_spec_as_decode:
@@ -281,6 +284,12 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
                 1 + speculative_config.num_speculative_tokens,
             )
 
+        if (
+            self.vllm_config.parallel_config.decode_context_parallel_size > 1
+            and not supports_dcp_with_varlen
+        ):
+            self.reorder_batch_threshold = 1
+
     @abstractmethod
     def build(
         self,
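
Note: below is a minimal standalone sketch of the rule this patch adds, for illustration only. The helper resolve_reorder_batch_threshold is hypothetical and not part of vLLM; the real logic lives in AttentionMetadataBuilder._init_reorder_batch_threshold, shown in the last hunk above. The idea: when decode context parallelism (DCP) is active and a backend has not opted in via supports_dcp_with_varlen, the reorder-batch threshold is clamped back to 1, so only single-token queries take the decode path; FlashAttn MLA opts in and keeps its wider threshold.

    def resolve_reorder_batch_threshold(
        reorder_batch_threshold: int,
        decode_context_parallel_size: int,
        supports_dcp_with_varlen: bool,
    ) -> int:
        # Mirrors the new guard: under DCP (decode_context_parallel_size > 1),
        # backends that cannot handle variable-length decode queries fall back
        # to a threshold of 1; otherwise the configured threshold is kept.
        if decode_context_parallel_size > 1 and not supports_dcp_with_varlen:
            return 1
        return reorder_batch_threshold

    # FlashAttnMLAMetadataBuilder passes supports_dcp_with_varlen=True, so it
    # keeps its threshold under DCP; other MLA builders default to False.
    assert resolve_reorder_batch_threshold(4, 2, False) == 1
    assert resolve_reorder_batch_threshold(4, 2, True) == 4
    assert resolve_reorder_batch_threshold(4, 1, False) == 4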