mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-10 07:15:01 +08:00
[Attention] Remove unused reorder_batch method (#24463)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent b8f603cebe
commit 4727a8afa7
@@ -581,7 +581,7 @@ def _generate_fake_step_update(
     persistent_batch[:] = persistent_batch[0:condensed_batch_size]

     if condensed_batch_size > 1:
-        # Simulate arbitrary reorder_batch() in the kernel backend
+        # Simulate arbitrary batch ordering in the kernel backend
         # Generate a random number k of non-overlapping swap tuples
         k = random.randint(0, condensed_batch_size // 2)
         idxs = list(range(condensed_batch_size))
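The test comments above spell out the shuffling strategy: pick k non-overlapping index pairs and swap them, so the persistent batch ends up in an arbitrary order, just as a kernel backend might leave it. A minimal, self-contained sketch of that idea (the helper name and the plain-list batch are illustrative, not vLLM code):

import random

def apply_random_swaps(batch: list, seed=None) -> list:
    """Shuffle a condensed batch in place with k non-overlapping swaps."""
    rng = random.Random(seed)
    n = len(batch)
    if n <= 1:
        return batch
    # Sampling 2k distinct indices guarantees the k swap pairs never overlap.
    k = rng.randint(0, n // 2)
    idxs = rng.sample(range(n), 2 * k)
    for a, b in zip(idxs[0::2], idxs[1::2]):
        batch[a], batch[b] = batch[b], batch[a]
    return batch

# Example: simulate arbitrary backend-side reordering of request ids.
print(apply_random_swaps(["req0", "req1", "req2", "req3", "req4"], seed=0))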
@@ -602,8 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             )
         else:
             # Regular attention (common case).
-            # Decodes are at the front and prefills are at the back,
-            # according to reorder_batch()
+            # Decodes are at the front and prefills are at the back.
             num_prefills = attn_metadata.num_prefills
             num_decodes = attn_metadata.num_decodes
             if num_prefills > 0:
@@ -925,8 +924,7 @@ class FlashInferImpl(AttentionImpl):
         stride_order = FlashInferBackend.get_kv_cache_stride_order()
         kv_cache_permute = kv_cache.permute(*stride_order)
         # Regular attention (common case).
-        # Decodes are at the front and prefills are at the back,
-        # according to reorder_batch()
+        # Decodes are at the front and prefills are at the back.
         if num_prefill_tokens > 0:
             prefill_wrapper = attn_metadata.prefill_wrapper
             prefill_query = query[num_decode_tokens:]
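Both FlashInfer hunks rely on the layout invariant the rewritten comment states: decode tokens sit at the front of the flattened query and prefill tokens at the back, so each phase is a plain slice at num_decode_tokens. A tiny sketch with made-up shapes (torch only, no vLLM objects):

import torch

# Assumed layout: all decode tokens packed before all prefill tokens.
num_decode_tokens, num_prefill_tokens, num_heads, head_dim = 3, 5, 2, 4
query = torch.randn(num_decode_tokens + num_prefill_tokens, num_heads, head_dim)

decode_query = query[:num_decode_tokens]   # rows 0..2: one token per decoding request
prefill_query = query[num_decode_tokens:]  # rows 3..7: prompt/chunked-prefill tokens

assert decode_query.shape[0] + prefill_query.shape[0] == query.shape[0]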
@@ -3,7 +3,7 @@
 """Attention layer with FlexAttention."""

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import Optional, Union

 import torch
 import torch._dynamo.decorators
@@ -38,10 +38,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec

 logger = init_logger(__name__)

-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
 create_block_mask_compiled = torch.compile(
     create_block_mask, fullgraph=True, mode="reduce-overhead"
 )
@@ -600,11 +596,6 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
         self.q_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128
         self.kv_block_size: int = 16 if is_torch_equal_or_newer("2.9.0.dev0") else 128

-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        return False
-
     def build(
         self,
         common_prefix_len: int,
@@ -4,10 +4,11 @@

 import ast
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import Optional

 import torch

+from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
@@ -20,17 +21,10 @@ from vllm.logger import init_logger
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    reorder_batch_to_split_decodes_and_prefills,
     split_decodes_and_prefills,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec

-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
-from vllm import _custom_ops as ops
-
 logger = init_logger(__name__)

@@ -189,12 +183,7 @@ class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadat
             device=device,
         )

-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        return reorder_batch_to_split_decodes_and_prefills(
-            input_batch, scheduler_output, decode_threshold=self.tree_attn_bias.shape[0]
-        )
+        self.reorder_batch_threshold = self.tree_attn_bias.shape[0]

     def build(
         self,
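This hunk shows the pattern the commit moves to: rather than overriding reorder_batch(), the TreeAttention builder only records a reorder_batch_threshold and leaves any reordering to shared code. A purely illustrative sketch of that shape, with hypothetical class and function names rather than vLLM's actual ones:

from typing import Optional

class BuilderBase:
    # None means the backend does not care about decode/prefill ordering.
    reorder_batch_threshold: Optional[int] = None

class TreeBuilder(BuilderBase):
    def __init__(self, tree_depth: int):
        # Requests scheduling at most tree_depth tokens count as "decodes".
        self.reorder_batch_threshold = tree_depth

def decode_first_order(builder: BuilderBase, num_scheduled_tokens: list) -> list:
    """Return request indices with decode-like requests first, if the builder asks for it."""
    t = builder.reorder_batch_threshold
    if t is None:
        return list(range(len(num_scheduled_tokens)))
    # Stable sort: False (decode) sorts before True (prefill), preserving relative order.
    return sorted(range(len(num_scheduled_tokens)),
                  key=lambda i: num_scheduled_tokens[i] > t)

print(decode_first_order(TreeBuilder(tree_depth=2), [1, 9, 2, 16]))  # [0, 2, 1, 3]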
@@ -299,24 +299,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
         """
         raise NotImplementedError

-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        """
-        Update the order of requests in the batch based on the attention
-        backend's needs. For example, some attention backends (namely MLA) may
-        want to separate requests based on if the attention computation will be
-        compute-bound or memory-bound.
-
-        Args:
-            input_batch: input batch
-            scheduler_output: scheduler output.
-
-        Returns:
-            True if the batch was modified, False otherwise.
-        """
-        raise NotImplementedError
-
     def build_for_cudagraph_capture(
         self, common_attn_metadata: CommonAttentionMetadata
     ) -> M:
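The deleted docstring motivates reordering with the compute-bound vs. memory-bound split. A back-of-the-envelope sketch of arithmetic intensity (attention FLOPs per byte of KV cache read) shows why a single-token decode is memory-bound while a long prefill is compute-bound; all sizes below are illustrative assumptions, not vLLM defaults:

head_dim, num_heads, kv_len, dtype_bytes = 128, 32, 4096, 2  # fp16 KV cache

def attn_intensity(q_tokens: int) -> float:
    # Two matmuls (QK^T and PV), 2 FLOPs per multiply-add.
    flops = 2 * 2 * q_tokens * kv_len * head_dim * num_heads
    # K and V are each read once for the request.
    bytes_moved = 2 * kv_len * head_dim * num_heads * dtype_bytes
    return flops / bytes_moved

print(f"decode  (1 query token):     {attn_intensity(1):7.1f} FLOPs/byte")     # ~1, memory-bound
print(f"prefill (4096 query tokens): {attn_intensity(4096):7.1f} FLOPs/byte")  # ~4096, compute-bound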
@@ -828,10 +810,6 @@ def reorder_batch_to_split_decodes_and_prefills(

     for i, req_id in enumerate(input_batch.req_ids):
         num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-        # for now treat 1 scheduled token as "decode" even if it's not,
-        # we should update this to something like < 8 in the future but
-        # currently the TritonMLA._forward_decode only supports
-        # num_tokens = 1
         if num_tokens <= decode_threshold:
             decodes.append(i)
             num_decode_tokens += num_tokens
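The surviving loop classifies each request purely by its scheduled token count: at or below decode_threshold it is bucketed as a decode, otherwise as a prefill, and per-bucket token totals are accumulated. A list/dict-based sketch of the same bucketing (not vLLM's InputBatch or SchedulerOutput types), with a worked example:

def split_by_threshold(num_scheduled_tokens: dict, decode_threshold: int = 1):
    decodes, prefills = [], []
    num_decode_tokens = num_prefill_tokens = 0
    for i, (req_id, num_tokens) in enumerate(num_scheduled_tokens.items()):
        if num_tokens <= decode_threshold:
            decodes.append(i)
            num_decode_tokens += num_tokens
        else:
            prefills.append(i)
            num_prefill_tokens += num_tokens
    return decodes, prefills, num_decode_tokens, num_prefill_tokens

# {"a": 1, "b": 57, "c": 1} -> decodes [0, 2] with 2 tokens, prefills [1] with 57 tokens.
print(split_by_threshold({"a": 1, "b": 57, "c": 1}))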
@@ -3,7 +3,7 @@
 """Attention layer with XFormersAttention."""

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import Optional

 import torch

@@ -19,7 +19,6 @@ from vllm.logger import init_logger
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
-    reorder_batch_to_split_decodes_and_prefills,
     split_decodes_and_prefills,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec
@@ -35,10 +34,6 @@ try:
 except ImportError:
     XFORMERS_AVAILABLE = False

-if TYPE_CHECKING:
-    from vllm.v1.core.sched.output import SchedulerOutput
-    from vllm.v1.worker.gpu_input_batch import InputBatch
-
 from vllm import _custom_ops as ops

 logger = init_logger(__name__)
@@ -223,13 +218,6 @@ class XFormersAttentionMetadataBuilder(
         self._num_decodes = 0
         self._num_decode_tokens = 0

-    def reorder_batch(
-        self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput"
-    ) -> bool:
-        return reorder_batch_to_split_decodes_and_prefills(
-            input_batch, scheduler_output, decode_threshold=self.reorder_batch_threshold
-        )
-
     def build(
         self,
         common_prefix_len: int,