mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-19 08:57:12 +08:00
[V1] [Hybrid] Validate compatibility of attention backend batch reordering at init time (#21557)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent
f5d0f4784f
commit
4abfd8796f
@ -4,7 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional, Union
|
||||
from typing import ClassVar, Optional, Union
|
||||
|
||||
import torch
|
||||
from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
|
||||
@ -21,17 +21,17 @@ from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv, is_pin_memory_available
|
||||
from vllm.utils.flashinfer import use_trtllm_decode_attention
|
||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
get_kv_cache_layout, get_per_layer_parameters,
|
||||
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
|
||||
split_decodes_and_prefills)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
get_kv_cache_layout,
|
||||
get_per_layer_parameters,
|
||||
infer_global_hyperparameters,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -179,6 +179,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.PURE_DECODE_ONLY
|
||||
|
||||
reorder_batch_threshold: ClassVar[int] = 1
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
self.device = device
|
||||
@ -239,12 +241,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
|
||||
def reorder_batch(self, input_batch: InputBatch,
|
||||
scheduler_output: SchedulerOutput) -> bool:
|
||||
return reorder_batch_to_split_decodes_and_prefills(input_batch,
|
||||
scheduler_output,
|
||||
decode_threshold=1)
|
||||
|
||||
def _get_workspace_buffer(self):
|
||||
if self._workspace_buffer is None:
|
||||
self._workspace_buffer = torch.empty(
|
||||
|
||||
@ -2,21 +2,17 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
|
||||
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
|
||||
chunk_size: int,
|
||||
@ -87,6 +83,8 @@ class Mamba2AttentionMetadata:
|
||||
class Mamba2AttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
|
||||
|
||||
reorder_batch_threshold: ClassVar[int] = 1
|
||||
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
vllm_config: VllmConfig, device: torch.device):
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
@ -95,12 +93,6 @@ class Mamba2AttentionMetadataBuilder(
|
||||
assert self.chunk_size is not None, (
|
||||
"chunk_size needs to be set in the model config for Mamba2 models")
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return reorder_batch_to_split_decodes_and_prefills(input_batch,
|
||||
scheduler_output,
|
||||
decode_threshold=1)
|
||||
|
||||
def build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
|
||||
@ -190,7 +190,7 @@ return curr_o @ W_O
|
||||
import functools
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
|
||||
from typing import ClassVar, Generic, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -210,10 +210,11 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv, round_down
|
||||
from vllm.utils.flashinfer import has_nvidia_artifactory
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
get_per_layer_parameters, infer_global_hyperparameters,
|
||||
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
get_per_layer_parameters,
|
||||
infer_global_hyperparameters,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
try:
|
||||
@ -233,10 +234,6 @@ try:
|
||||
except ImportError:
|
||||
flashinfer_available = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
CUDNN_WORKSPACE_SIZE = 12800
|
||||
@ -403,6 +400,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
reorder_batch_threshold: ClassVar[int] = 1
|
||||
|
||||
def __init__(self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
@ -559,12 +557,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
||||
prefill.prefill_main = self._fi_prefill_main
|
||||
prefill.prefill_chunks = self._fi_prefill_chunks
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return reorder_batch_to_split_decodes_and_prefills(input_batch,
|
||||
scheduler_output,
|
||||
decode_threshold=1)
|
||||
|
||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||
seq_lens: torch.Tensor):
|
||||
return MLACommonDecodeMetadata(
|
||||
|
||||
@ -251,9 +251,6 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
self.aot_sliding_window: Optional[tuple[int, int]] = None
|
||||
self.total_tokens: int = 0
|
||||
|
||||
def reorder_batch(self, input_batch, scheduler_output) -> bool:
|
||||
return False
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata):
|
||||
self.total_tokens = self.model_config.max_model_len \
|
||||
|
||||
@ -167,6 +167,10 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
# Does this backend/builder support CUDA Graphs for attention.
|
||||
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.NEVER
|
||||
# Does this backend/builder reorder the batch?
|
||||
# If not, set this to None. Otherwise set it to the query
|
||||
# length that will be pulled into the front of the batch.
|
||||
reorder_batch_threshold: ClassVar[Optional[int]] = None
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
|
||||
@ -221,14 +225,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
"""
|
||||
This method can reorder the batch if desired by the backend.
|
||||
:return: Has the batch been reordered (default False).
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def get_kv_cache_layout():
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -9,8 +9,12 @@ import torch.nn as nn
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -27,6 +31,34 @@ class CPUModelRunner(GPUModelRunner):
|
||||
|
||||
self._postprocess_tenosrs()
|
||||
|
||||
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
"""
|
||||
Update the order of requests in the batch based on the attention
|
||||
backend's needs. For example, some attention backends (namely MLA) may
|
||||
want to separate requests based on if the attention computation will be
|
||||
compute-bound or memory-bound.
|
||||
|
||||
Args:
|
||||
scheduler_output: The scheduler output.
|
||||
"""
|
||||
# Attention free models have zero kv_cache_goups, however models
|
||||
# like Mamba are also attention free but use the kv_cache for
|
||||
# keeping its internal state. This is why we check the number
|
||||
# of kv_cache groups instead of solely checking
|
||||
# for self.model_config.is_attention_free.
|
||||
if len(self.kv_cache_config.kv_cache_groups) == 0:
|
||||
return
|
||||
|
||||
if len(self.kv_cache_config.kv_cache_groups) > 1:
|
||||
raise ValueError("Multiple KVCacheGroups is not"
|
||||
"currently supported with CPU model runner.")
|
||||
|
||||
assert type(
|
||||
self.attn_metadata_builders[0]) is TorchSDPAMetadataBuilderV1
|
||||
|
||||
self.attn_metadata_builders[0].reorder_batch(self.input_batch,
|
||||
scheduler_output)
|
||||
|
||||
def _postprocess_tenosrs(self) -> None:
|
||||
# Note: replace device tensors with cpu tensors
|
||||
def replace_tensor(obj: Any, cpu_attr_name: str,
|
||||
|
||||
@ -49,7 +49,8 @@ from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
|
||||
make_kv_sharing_fast_prefill_attention_metadata,
|
||||
make_local_attention_virtual_batches)
|
||||
make_local_attention_virtual_batches,
|
||||
reorder_batch_to_split_decodes_and_prefills)
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||
ChunkedLocalAttentionSpec,
|
||||
@ -329,6 +330,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
|
||||
self.max_num_tokens, dtype=torch.int32, device=self.device)
|
||||
|
||||
self.reorder_batch_threshold: Optional[int] = None
|
||||
|
||||
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
"""
|
||||
Update the order of requests in the batch based on the attention
|
||||
@ -347,20 +350,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if len(self.kv_cache_config.kv_cache_groups) == 0:
|
||||
return
|
||||
|
||||
self.attn_metadata_builders[0].reorder_batch(self.input_batch,
|
||||
scheduler_output)
|
||||
|
||||
# For models with multiple KV cache groups, the groups should agree on
|
||||
# the same order of requests. We ensure this by only allowing the first
|
||||
# group to reorder the batch and asserting that all other groups do not
|
||||
# reorder the batch.
|
||||
# TODO(tdoublep): make this more flexible so that any group can
|
||||
# re-order the batch (not only the first).
|
||||
# TODO(tdoublep): verify this during engine init instead of at runtime
|
||||
for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
|
||||
batch_reordered = self.attn_metadata_builders[i].reorder_batch(
|
||||
self.input_batch, scheduler_output)
|
||||
assert not batch_reordered
|
||||
if self.reorder_batch_threshold is not None:
|
||||
reorder_batch_to_split_decodes_and_prefills(
|
||||
self.input_batch,
|
||||
scheduler_output,
|
||||
decode_threshold=self.reorder_batch_threshold)
|
||||
|
||||
# Note: used for model runner override.
|
||||
def _init_device_properties(self) -> None:
|
||||
@ -2654,6 +2648,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.attn_backends.append(attn_backend_i)
|
||||
self.attn_metadata_builders.append(attn_metadata_builder_i)
|
||||
|
||||
# Calculate reorder batch threshold (if neeeded)
|
||||
self.calculate_reorder_batch_threshold()
|
||||
|
||||
if len(self.attn_backends) > 0:
|
||||
return
|
||||
|
||||
@ -2688,6 +2685,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.attn_metadata_builders.append(attn_metadata_builder)
|
||||
self.is_encoder_only_model = True
|
||||
|
||||
def calculate_reorder_batch_threshold(self) -> None:
|
||||
"""
|
||||
Check that if any backends reorder batches; that the reordering
|
||||
is compatible (e.g., decode threshold is the same)
|
||||
"""
|
||||
for attn_metadata_builder_i in self.attn_metadata_builders:
|
||||
# check that if any backends reorder batches; that the reordering
|
||||
# is compatible (e.g., decode threshold is the same)
|
||||
reorder_batch_threshold_i = (
|
||||
attn_metadata_builder_i.reorder_batch_threshold)
|
||||
if reorder_batch_threshold_i is not None:
|
||||
if self.reorder_batch_threshold is not None:
|
||||
if reorder_batch_threshold_i != \
|
||||
self.reorder_batch_threshold:
|
||||
raise ValueError(
|
||||
f"Attention backend reorders decodes with "
|
||||
f"threshold {reorder_batch_threshold_i} but other "
|
||||
f"backend uses threshold "
|
||||
f"{self.reorder_batch_threshold}")
|
||||
else:
|
||||
self.reorder_batch_threshold = reorder_batch_threshold_i
|
||||
|
||||
def may_reinitialize_input_batch(self,
|
||||
kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user