diff --git a/tests/v1/logits_processors/test_correctness.py b/tests/v1/logits_processors/test_correctness.py
index 34997b7e7a43..538b6281f5a0 100644
--- a/tests/v1/logits_processors/test_correctness.py
+++ b/tests/v1/logits_processors/test_correctness.py
@@ -581,7 +581,7 @@ def _generate_fake_step_update(
         persistent_batch[:] = persistent_batch[0:condensed_batch_size]
 
     if condensed_batch_size > 1:
-        # Simulate arbitrary reorder_batch() in the kernel backend
+        # Simulate arbitrary batch ordering in the kernel backend
         # Generate a random number k of non-overlapping swap tuples
         k = random.randint(0, condensed_batch_size // 2)
         idxs = list(range(condensed_batch_size))
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index c7a826a67d7d..38cf0ca56733 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -602,8 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             )
         else:
             # Regular attention (common case).
-            # Decodes are at the front and prefills are at the back,
-            # according to reorder_batch()
+            # Decodes are at the front and prefills are at the back.
             num_prefills = attn_metadata.num_prefills
             num_decodes = attn_metadata.num_decodes
             if num_prefills > 0:
@@ -925,8 +924,7 @@ class FlashInferImpl(AttentionImpl):
         stride_order = FlashInferBackend.get_kv_cache_stride_order()
         kv_cache_permute = kv_cache.permute(*stride_order)
         # Regular attention (common case).
-        # Decodes are at the front and prefills are at the back,
-        # according to reorder_batch()
+        # Decodes are at the front and prefills are at the back.
         if num_prefill_tokens > 0:
             prefill_wrapper = attn_metadata.prefill_wrapper
             prefill_query = query[num_decode_tokens:]
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index 4640e62abfe6..7775445ae773 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -3,7 +3,7 @@
 """Attention layer with FlexAttention."""
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import Optional, Union
 
 import torch
 import torch._dynamo.decorators
@@ -38,10 +38,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 
 logger = init_logger(__name__)
 
-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
 create_block_mask_compiled = torch.compile(
     create_block_mask, fullgraph=True, mode="reduce-overhead"
 )
@@ -600,11 +596,6 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
         self.q_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
         self.kv_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
 
-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        return False
-
     def build(
         self,
         common_prefix_len: int,
diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py
index 2a7770c87d24..a209bb79580c 100644
--- a/vllm/v1/attention/backends/tree_attn.py
+++ b/vllm/v1/attention/backends/tree_attn.py
@@ -4,10 +4,11 @@
 
 import ast
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import Optional
 
 import torch
 
+from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
@@ -20,17 +21,10 @@ from vllm.logger import init_logger
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    reorder_batch_to_split_decodes_and_prefills,
     split_decodes_and_prefills,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
 
-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
-from vllm import _custom_ops as ops
-
 logger = init_logger(__name__)
 
 
@@ -189,12 +183,7 @@ class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadat
             device=device,
         )
 
-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        return reorder_batch_to_split_decodes_and_prefills(
-            input_batch, scheduler_output, decode_threshold=self.tree_attn_bias.shape[0]
-        )
+        self.reorder_batch_threshold = self.tree_attn_bias.shape[0]
 
     def build(
         self,
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index bddb2f22f0dc..003c7253e553 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -299,24 +299,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
         """
         raise NotImplementedError
 
-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        """
-        Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may
-        want to separate requests based on if the attention computation will be
-        compute-bound or memory-bound.
-
-        Args:
-            input_batch: input batch
-            scheduler_output: scheduler output.
-
-        Returns:
-            True if the batch was modified, False otherwise.
-        """
-        raise NotImplementedError
-
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
     ) -> M:
@@ -828,10 +810,6 @@ def reorder_batch_to_split_decodes_and_prefills(
 
     for i, req_id in enumerate(input_batch.req_ids):
         num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-        # for now treat 1 scheduled token as "decode" even if it's not,
-        # we should update this to something like < 8 in the future but
-        # currently the TritonMLA._forward_decode only supports
-        # num_tokens = 1
         if num_tokens <= decode_threshold:
             decodes.append(i)
             num_decode_tokens += num_tokens
diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py
index 17e752277c66..b21562fac741 100644
--- a/vllm/v1/attention/backends/xformers.py
+++ b/vllm/v1/attention/backends/xformers.py
@@ -3,7 +3,7 @@
 """Attention layer with XFormersAttention."""
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import Optional
 
 import torch
 
@@ -19,7 +19,6 @@ from vllm.logger import init_logger
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    reorder_batch_to_split_decodes_and_prefills,
     split_decodes_and_prefills,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -35,10 +34,6 @@ try:
 except ImportError:
     XFORMERS_AVAILABLE = False
 
-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
 from vllm import _custom_ops as ops
 
 logger = init_logger(__name__)
@@ -223,13 +218,6 @@ class XFormersAttentionMetadataBuilder(
         self._num_decodes = 0
         self._num_decode_tokens = 0
 
-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        return reorder_batch_to_split_decodes_and_prefills(
-            input_batch, scheduler_output, decode_threshold=self.reorder_batch_threshold
-        )
-
     def build(
         self,
         common_prefix_len: int,
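Note: as the hunks above show, the per-backend reorder_batch() overrides are replaced by a reorder_batch_threshold attribute set in the builder's __init__ (e.g. self.reorder_batch_threshold = self.tree_attn_bias.shape[0] in the tree-attention builder), while the shared reorder_batch_to_split_decodes_and_prefills() helper in utils.py keeps the actual split logic. The sketch below is illustrative only, not vLLM code; the function and parameter names are hypothetical. It shows the classification rule retained in the helper: a request whose scheduled token count is at most the threshold is treated as a decode and grouped at the front of the batch, everything else as a prefill at the back.

# Illustrative sketch only -- not vLLM code. Mirrors the rule kept in
# reorder_batch_to_split_decodes_and_prefills(): requests scheduled with at
# most `decode_threshold` tokens are classified as decodes (front of batch),
# the rest as prefills (back of batch).
def split_decodes_and_prefills_sketch(
    num_scheduled_tokens: list[int], decode_threshold: int = 1
) -> tuple[list[int], list[int]]:
    decodes: list[int] = []
    prefills: list[int] = []
    for i, num_tokens in enumerate(num_scheduled_tokens):
        if num_tokens <= decode_threshold:
            decodes.append(i)
        else:
            prefills.append(i)
    return decodes, prefills


# Example: with the default threshold of 1, only single-token requests
# land in the decode group.
print(split_decodes_and_prefills_sketch([1, 17, 1, 5]))  # ([0, 2], [1, 3])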