[Attention] Remove unused reorder_batch method (#24463)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Matthew Bonanni 2025-10-06 13:13:39 -04:00 committed by GitHub
parent b8f603cebe
commit 4727a8afa7
6 changed files with 8 additions and 64 deletions

View File

@@ -581,7 +581,7 @@ def _generate_fake_step_update(
persistent_batch[:] = persistent_batch[0:condensed_batch_size]
if condensed_batch_size > 1:
# Simulate arbitrary reorder_batch() in the kernel backend
# Simulate arbitrary batch ordering in the kernel backend
# Generate a random number k of non-overlapping swap tuples
k = random.randint(0, condensed_batch_size // 2)
idxs = list(range(condensed_batch_size))

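For context on the test above: the persistent batch is shuffled by a random number of non-overlapping swaps to mimic whatever ordering a kernel backend might impose. A minimal standalone sketch of that swap generation, with illustrative names that are not the test's own helpers:

import random

def apply_random_swaps(batch: list, rng: random.Random) -> list:
    # Draw k disjoint index pairs from a shuffled permutation and swap
    # each pair, simulating an arbitrary backend-specific reordering.
    n = len(batch)
    k = rng.randint(0, n // 2)
    idxs = list(range(n))
    rng.shuffle(idxs)
    for a, b in zip(idxs[:k], idxs[k:2 * k]):
        batch[a], batch[b] = batch[b], batch[a]
    return batch

# Example: a 6-element batch reordered in place.
print(apply_random_swaps(list("abcdef"), random.Random(0)))
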
View File

@@ -602,8 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
)
else:
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
# Decodes are at the front and prefills are at the back.
num_prefills = attn_metadata.num_prefills
num_decodes = attn_metadata.num_decodes
if num_prefills > 0:
@@ -925,8 +924,7 @@ class FlashInferImpl(AttentionImpl):
stride_order = FlashInferBackend.get_kv_cache_stride_order()
kv_cache_permute = kv_cache.permute(*stride_order)
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
# Decodes are at the front and prefills are at the back.
if num_prefill_tokens > 0:
prefill_wrapper = attn_metadata.prefill_wrapper
prefill_query = query[num_decode_tokens:]

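The comment above captures the invariant this backend relies on: decode requests are kept ahead of prefill requests, so the query tensor splits into two contiguous views with plain slicing. A rough illustration of that layout assumption (shapes are made up for the example):

import torch

num_decode_tokens, num_prefill_tokens, num_heads, head_dim = 4, 12, 8, 64
query = torch.randn(num_decode_tokens + num_prefill_tokens, num_heads, head_dim)

# Decodes first, prefills after: both halves are contiguous slices.
decode_query = query[:num_decode_tokens]
prefill_query = query[num_decode_tokens:]
assert decode_query.shape[0] == num_decode_tokens
assert prefill_query.shape[0] == num_prefill_tokens
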
View File

@@ -3,7 +3,7 @@
"""Attention layer with FlexAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
from typing import Optional, Union
import torch
import torch._dynamo.decorators
@@ -38,10 +38,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
create_block_mask_compiled = torch.compile(
create_block_mask, fullgraph=True, mode="reduce-overhead"
)
@@ -600,11 +596,6 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
self.q_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
self.kv_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
def reorder_batch(
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
) -> bool:
return False
def build(
self,
common_prefix_len: int,

View File

@@ -4,10 +4,11 @@
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import Optional
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
@@ -20,17 +21,10 @@ from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops
logger = init_logger(__name__)
@@ -189,12 +183,7 @@ class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadat
device=device,
)
def reorder_batch(
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
) -> bool:
return reorder_batch_to_split_decodes_and_prefills(
input_batch, scheduler_output, decode_threshold=self.tree_attn_bias.shape[0]
)
self.reorder_batch_threshold = self.tree_attn_bias.shape[0]
def build(
self,

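The pattern after this change: instead of overriding reorder_batch(), a builder that needs decode/prefill separation declares a reorder_batch_threshold and lets the shared reordering code act on it. A simplified sketch of that contract; the attribute name matches the diff, but the surrounding class and the call site are invented for illustration:

import torch

class TreeAttentionLikeBuilder:
    # Shared code treats requests with at most this many scheduled tokens
    # as decodes and moves them to the front of the batch.
    reorder_batch_threshold: int = 1

    def __init__(self, tree_attn_bias: torch.Tensor):
        # For tree attention the threshold is the size of the speculative
        # tree, i.e. how many tokens a "decode" step may schedule.
        self.reorder_batch_threshold = tree_attn_bias.shape[0]

builder = TreeAttentionLikeBuilder(torch.zeros(3, 3))
assert builder.reorder_batch_threshold == 3

# Presumed call site (assumption): the shared utility receives this value
# as decode_threshold, e.g.
#   reorder_batch_to_split_decodes_and_prefills(
#       input_batch, scheduler_output,
#       decode_threshold=builder.reorder_batch_threshold)
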
View File

@@ -299,24 +299,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
"""
raise NotImplementedError
def reorder_batch(
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
) -> bool:
"""
Update the order of requests in the batch based on the attention
backend's needs. For example, some attention backends (namely MLA) may
want to separate requests based on if the attention computation will be
compute-bound or memory-bound.
Args:
input_batch: input batch
scheduler_output: scheduler output.
Returns:
True if the batch was modified, False otherwise.
"""
raise NotImplementedError
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> M:
@@ -828,10 +810,6 @@ def reorder_batch_to_split_decodes_and_prefills(
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if it's not,
# we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports
# num_tokens = 1
if num_tokens <= decode_threshold:
decodes.append(i)
num_decode_tokens += num_tokens

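The rule visible in the last hunk is the core classification step of reorder_batch_to_split_decodes_and_prefills: a request counts as a decode when its scheduled token count is at or below decode_threshold. A stripped-down sketch of that split; the in-place swapping that actually reorders the input batch is omitted, and the prefill branch is assumed to mirror the decode branch:

def split_by_decode_threshold(req_ids, num_scheduled_tokens, decode_threshold=1):
    # Classify each request index as decode or prefill by its scheduled
    # token count; callers can then move decodes to the front.
    decodes, prefills = [], []
    num_decode_tokens = num_prefill_tokens = 0
    for i, req_id in enumerate(req_ids):
        num_tokens = num_scheduled_tokens[req_id]
        if num_tokens <= decode_threshold:
            decodes.append(i)
            num_decode_tokens += num_tokens
        else:
            prefills.append(i)
            num_prefill_tokens += num_tokens
    return decodes, prefills, num_decode_tokens, num_prefill_tokens

# Example: with the default threshold of 1, "r1" and "r3" are decodes.
print(split_by_decode_threshold(["r1", "r2", "r3"], {"r1": 1, "r2": 7, "r3": 1}))
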
View File

@@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
from typing import Optional
import torch
@@ -19,7 +19,6 @@ from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -35,10 +34,6 @@ try:
except ImportError:
XFORMERS_AVAILABLE = False
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm import _custom_ops as ops
logger = init_logger(__name__)
@@ -223,13 +218,6 @@ class XFormersAttentionMetadataBuilder(
self._num_decodes = 0
self._num_decode_tokens = 0
def reorder_batch(
self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
) -> bool:
return reorder_batch_to_split_decodes_and_prefills(
input_batch, scheduler_output, decode_threshold=self.reorder_batch_threshold
)
def build(
self,
common_prefix_len: int,