[Attention] Remove unused reorder_batch method (#24463)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
commit 4727a8afa7
parent b8f603cebe
@@ -581,7 +581,7 @@ def _generate_fake_step_update(
     persistent_batch[:] = persistent_batch[0:condensed_batch_size]

     if condensed_batch_size > 1:
-        # Simulate arbitrary reorder_batch() in the kernel backend
+        # Simulate arbitrary batch ordering in the kernel backend
         # Generate a random number k of non-overlapping swap tuples
         k = random.randint(0, condensed_batch_size // 2)
         idxs = list(range(condensed_batch_size))
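Illustrative sketch, not part of the diff: the swap-based shuffle the test comment describes can be written as a standalone helper roughly like the one below. The function names and the example batch are invented for this illustration.

import random

def random_nonoverlapping_swaps(batch_size: int, rng: random.Random) -> list[tuple[int, int]]:
    # Pick k disjoint index pairs; swapping them simulates an arbitrary
    # backend-driven reordering of the persistent batch.
    k = rng.randint(0, batch_size // 2)
    idxs = list(range(batch_size))
    rng.shuffle(idxs)
    # Consecutive entries of a shuffled permutation are disjoint by construction.
    return [(idxs[2 * i], idxs[2 * i + 1]) for i in range(k)]

def apply_swaps(batch: list, swaps: list[tuple[int, int]]) -> None:
    for a, b in swaps:
        batch[a], batch[b] = batch[b], batch[a]

# Example: shuffle a toy batch of request ids.
rng = random.Random(0)
batch = ["req0", "req1", "req2", "req3", "req4"]
apply_swaps(batch, random_nonoverlapping_swaps(len(batch), rng))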
@@ -602,8 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
            )
        else:
            # Regular attention (common case).
-            # Decodes are at the front and prefills are at the back,
-            # according to reorder_batch()
+            # Decodes are at the front and prefills are at the back.
            num_prefills = attn_metadata.num_prefills
            num_decodes = attn_metadata.num_decodes
            if num_prefills > 0:
@@ -925,8 +924,7 @@ class FlashInferImpl(AttentionImpl):
        stride_order = FlashInferBackend.get_kv_cache_stride_order()
        kv_cache_permute = kv_cache.permute(*stride_order)
        # Regular attention (common case).
-        # Decodes are at the front and prefills are at the back,
-        # according to reorder_batch()
+        # Decodes are at the front and prefills are at the back.
        if num_prefill_tokens > 0:
            prefill_wrapper = attn_metadata.prefill_wrapper
            prefill_query = query[num_decode_tokens:]
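Illustrative sketch, not part of the diff: the comment reworded in the two FlashInfer hunks above describes a layout invariant, namely that decode requests sit at the front of the batch and prefill requests at the back, so the flattened query tensor can be partitioned by a single token offset. A minimal sketch of that slicing follows; the function and variable names are assumptions for the example, not taken from the code.

import torch

def split_query_by_layout(query: torch.Tensor, num_decode_tokens: int):
    # Under the decodes-first layout, the first num_decode_tokens rows of the
    # flattened [num_tokens, num_heads, head_dim] query belong to decode
    # requests and the remaining rows belong to prefills.
    decode_query = query[:num_decode_tokens]
    prefill_query = query[num_decode_tokens:]
    return decode_query, prefill_query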
@@ -3,7 +3,7 @@
 """Attention layer with FlexAttention."""

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import Optional, Union

 import torch
 import torch._dynamo.decorators
@@ -38,10 +38,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec

 logger = init_logger(__name__)

-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
 create_block_mask_compiled = torch.compile(
     create_block_mask, fullgraph=True, mode="reduce-overhead"
 )
@@ -600,11 +596,6 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
         self.q_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
         self.kv_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128

-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        return False
-
     def build(
         self,
         common_prefix_len: int,
@@ -4,10 +4,11 @@

 import ast
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import Optional

 import torch

+from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
@@ -20,17 +21,10 @@ from vllm.logger import init_logger
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    reorder_batch_to_split_decodes_and_prefills,
     split_decodes_and_prefills,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec

-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
-from vllm import _custom_ops as ops
-
 logger = init_logger(__name__)

@@ -189,12 +183,7 @@ class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadat
             device=device,
         )

-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        return reorder_batch_to_split_decodes_and_prefills(
-            input_batch, scheduler_output, decode_threshold=self.tree_attn_bias.shape[0]
-        )
+        self.reorder_batch_threshold = self.tree_attn_bias.shape[0]

     def build(
         self,
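Illustrative sketch, not part of the diff: instead of overriding reorder_batch(), the builder now records a decode threshold and leaves the actual reordering to shared machinery. The toy classes below show the shape of that contract; the class names and the maybe_reorder helper are hypothetical and only stand in for the real base class and reorder utility.

from typing import Optional

class ToyMetadataBuilder:
    # A backend opts into decode/prefill splitting by setting a threshold;
    # None means "no reordering needed".
    reorder_batch_threshold: Optional[int] = None

class ToyTreeAttentionBuilder(ToyMetadataBuilder):
    def __init__(self, tree_width: int) -> None:
        # Requests with at most tree_width scheduled tokens are treated as decodes.
        self.reorder_batch_threshold = tree_width

def maybe_reorder(builder: ToyMetadataBuilder, num_scheduled_tokens: list[int]) -> list[int]:
    # Stand-in for the shared reorder utility: return request indices with
    # decodes (<= threshold tokens) first and prefills after, preserving order.
    if builder.reorder_batch_threshold is None:
        return list(range(len(num_scheduled_tokens)))
    t = builder.reorder_batch_threshold
    return sorted(range(len(num_scheduled_tokens)), key=lambda i: num_scheduled_tokens[i] > t)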
@@ -299,24 +299,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
         """
         raise NotImplementedError

-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        """
-        Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may
-        want to separate requests based on if the attention computation will be
-        compute-bound or memory-bound.
-
-        Args:
-            input_batch: input batch
-            scheduler_output: scheduler output.
-
-        Returns:
-            True if the batch was modified, False otherwise.
-        """
-        raise NotImplementedError
-
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
     ) -> M:
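Illustrative sketch, not part of the diff: the removed docstring mentions separating requests by whether their attention work is compute-bound or memory-bound. A very rough back-of-the-envelope heuristic for that distinction is sketched below; the function name, constants, and threshold are all illustrative assumptions, not anything defined in vLLM.

def is_memory_bound(num_query_tokens: int, kv_len: int, head_dim: int = 128) -> bool:
    # Attention FLOPs grow with num_query_tokens * kv_len, while the KV-cache
    # bytes read grow with kv_len alone, so single-token decodes have low
    # arithmetic intensity (memory-bound) and long prefills have high
    # intensity (compute-bound).
    flops = 4 * num_query_tokens * kv_len * head_dim  # QK^T and PV matmuls
    bytes_read = 2 * kv_len * head_dim * 2            # K and V in fp16
    arithmetic_intensity = flops / bytes_read         # roughly num_query_tokens
    return arithmetic_intensity < 64                  # illustrative roofline knee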
@@ -828,10 +810,6 @@ def reorder_batch_to_split_decodes_and_prefills(

     for i, req_id in enumerate(input_batch.req_ids):
         num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-        # for now treat 1 scheduled token as "decode" even if it's not,
-        # we should update this to something like < 8 in the future but
-        # currently the TritonMLA._forward_decode only supports
-        # num_tokens = 1
         if num_tokens <= decode_threshold:
             decodes.append(i)
             num_decode_tokens += num_tokens
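Illustrative sketch, not part of the diff: the loop above classifies each request by its scheduled token count. A self-contained, simplified stand-in for reorder_batch_to_split_decodes_and_prefills, operating on a plain list instead of the real InputBatch, could look like this (the _toy suffix marks it as hypothetical):

def split_decodes_and_prefills_toy(num_scheduled_tokens: list[int], decode_threshold: int = 1):
    # Classify requests: anything at or below the threshold counts as a decode.
    decodes = [i for i, n in enumerate(num_scheduled_tokens) if n <= decode_threshold]
    prefills = [i for i, n in enumerate(num_scheduled_tokens) if n > decode_threshold]
    num_decode_tokens = sum(num_scheduled_tokens[i] for i in decodes)
    num_prefill_tokens = sum(num_scheduled_tokens[i] for i in prefills)
    # Reordered request indices: decodes first, prefills after.
    order = decodes + prefills
    modified = order != list(range(len(num_scheduled_tokens)))
    return order, num_decode_tokens, num_prefill_tokens, modified

# Example: three single-token decodes and two prefills.
order, nd, np_, changed = split_decodes_and_prefills_toy([1, 17, 1, 1, 32])
# order == [0, 2, 3, 1, 4], nd == 3, np_ == 49, changed == True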
@@ -3,7 +3,7 @@
 """Attention layer with XFormersAttention."""

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import Optional

 import torch

@@ -19,7 +19,6 @@ from vllm.logger import init_logger
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    reorder_batch_to_split_decodes_and_prefills,
     split_decodes_and_prefills,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -35,10 +34,6 @@ try:
 except ImportError:
     XFORMERS_AVAILABLE = False

-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
 from vllm import _custom_ops as ops

 logger = init_logger(__name__)
@@ -223,13 +218,6 @@ class XFormersAttentionMetadataBuilder(
         self._num_decodes = 0
         self._num_decode_tokens = 0

-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        return reorder_batch_to_split_decodes_and_prefills(
-            input_batch, scheduler_output, decode_threshold=self.reorder_batch_threshold
-        )
-
     def build(
         self,
         common_prefix_len: int,