[Core] Add AttentionState abstraction (#7663)

Author: Antoni Baum, 2024-08-20 11:50:45 -07:00 (committed by GitHub)
parent c6af027a35
commit 3b682179dd
16 changed files with 372 additions and 247 deletions
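
At a high level, this commit moves the FlashInfer-specific CUDA-graph bookkeeping out of the model runners and behind a per-backend AttentionState object that every AttentionBackend must now hand out via get_state_cls(). A minimal sketch of that hook, using only names introduced in the diff below; make_attn_state is a hypothetical helper (the real runner calls get_state_cls() inline in model_runner.py):

from typing import Type

from vllm.attention.backends.abstract import AttentionBackend, AttentionState


def make_attn_state(backend_cls: Type[AttentionBackend], runner) -> AttentionState:
    # Each backend advertises which AttentionState implementation it needs;
    # the model runner builds exactly one instance and reuses it for its
    # whole lifetime (see GPUModelRunnerBase.__init__ further down).
    return backend_cls.get_state_cls()(runner)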

View File

@ -5,6 +5,7 @@ import torch
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.embedding_model_runner import (
@ -29,7 +30,11 @@ class MockAttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
raise AttentionMetadataBuilder
return AttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(

View File

@ -1,7 +1,7 @@
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
AttentionState, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
@ -12,5 +12,6 @@ __all__ = [
"AttentionType",
"AttentionMetadataBuilder",
"Attention",
"AttentionState",
"get_attn_backend",
]

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
@ -7,7 +8,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
import torch
if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase
from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase,
ModelRunnerInputBuilderBase)
class AttentionType(Enum):
@ -34,6 +37,11 @@ class AttentionBackend(ABC):
def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_state_cls() -> Type["AttentionState"]:
raise NotImplementedError
@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)
@ -126,6 +134,47 @@ class AttentionMetadata:
T = TypeVar("T", bound=AttentionMetadata)
class AttentionState(ABC, Generic[T]):
"""Holds attention backend-specific objects reused during the
lifetime of the model runner."""
@abstractmethod
def __init__(self, runner: "ModelRunnerBase"):
...
@abstractmethod
@contextmanager
def graph_capture(self, max_batch_size: int):
"""Context manager used when capturing CUDA graphs."""
yield
@abstractmethod
def graph_clone(self, batch_size: int) -> "AttentionState[T]":
"""Clone attention state to save in CUDA graph metadata."""
...
@abstractmethod
def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T:
"""Get attention metadata for CUDA graph capture of batch_size."""
...
@abstractmethod
def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]:
"""Get attention-specific input buffers for CUDA graph capture."""
...
@abstractmethod
def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
attn_metadata: T) -> None:
"""In-place modify input buffers dict for CUDA graph replay."""
...
@abstractmethod
def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
"""Prepare state for forward pass."""
...
class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""

View File

@ -5,7 +5,8 @@ import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
@ -98,6 +99,10 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
return BlocksparseFlashAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,

View File

@ -9,7 +9,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
@ -142,6 +143,10 @@ class FlashAttentionBackend(AttentionBackend):
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,

View File

@ -1,14 +1,19 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
import vllm.attention.backends.flash_attn # noqa
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
import torch
@ -16,7 +21,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
AttentionState, AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
@ -46,6 +51,10 @@ class FlashInferBackend(AttentionBackend):
def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
return FlashInferMetadataBuilder
@staticmethod
def get_state_cls() -> Type["FlashInferState"]:
return FlashInferState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@ -75,6 +84,160 @@ class FlashInferBackend(AttentionBackend):
return [64, 128, 256]
class FlashInferState(AttentionState):
def __init__(self, runner):
self.runner = runner
self._is_graph_capturing = False
self._workspace_buffer = None
self._decode_wrapper = None
self._prefill_wrapper = None
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.runner.device)
return self._workspace_buffer
def _get_prefill_wrapper(self):
if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), "NHD")
return self._prefill_wrapper
def _get_decode_wrapper(self):
if self._decode_wrapper is None:
num_qo_heads = (self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = num_qo_heads // num_kv_heads >= 4
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
"NHD",
use_tensor_cores=use_tensor_cores)
return self._decode_wrapper
@contextmanager
def graph_capture(self, max_batch_size: int):
self._is_graph_capturing = True
self._graph_decode_wrapper = None
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.long,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
self._graph_decode_workspace_buffer = self._get_workspace_buffer()
self._graph_indices_buffer = torch.empty(
max_batch_size * self.runner.cache_config.num_gpu_blocks,
dtype=torch.int32,
device=self.runner.device)
self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
dtype=torch.int32,
device=self.runner.device)
self._graph_last_page_len_buffer = torch.empty(
max_batch_size, dtype=torch.int32, device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
del self._graph_decode_workspace_buffer
del self._graph_indices_buffer
del self._graph_indptr_buffer
del self._graph_last_page_len_buffer
del self._graph_decode_wrapper
def graph_clone(self, batch_size: int):
assert self._is_graph_capturing
state = self.__class__(self.runner)
state._workspace_buffer = self._graph_decode_workspace_buffer
state._decode_wrapper = self._graph_decode_wrapper
state._prefill_wrapper = self._get_prefill_wrapper()
return state
def graph_capture_get_metadata_for_batch(self, batch_size: int):
assert self._is_graph_capturing
_indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
_last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]
num_qo_heads = (self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = num_qo_heads // num_kv_heads >= 4
self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer, "NHD",
use_tensor_cores)
kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)
paged_kv_indptr_tensor_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)
paged_kv_indices_tensor_host = torch.arange(0,
batch_size,
dtype=torch.int32)
paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
self.runner.block_size,
dtype=torch.int32)
query_start_loc_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
slot_mapping=self._graph_slot_mapping[:batch_size],
num_prefill_tokens=0,
num_decode_tokens=batch_size,
max_prefill_seq_len=0,
block_tables=self._graph_block_tables,
paged_kv_indptr=paged_kv_indptr_tensor_host,
paged_kv_indices=paged_kv_indices_tensor_host,
paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=self.runner.model_config.get_head_size(),
page_size=self.runner.block_size,
seq_start_loc=None,
query_start_loc=query_start_loc_host,
device=self.runner.device,
data_type=kv_cache_dtype,
use_cuda_graph=True,
decode_wrapper=self._graph_decode_wrapper,
prefill_wrapper=None)
attn_metadata.begin_forward()
return attn_metadata
def get_graph_input_buffers(self, attn_metadata):
return {
"slot_mapping": attn_metadata.slot_mapping,
}
def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
return
def begin_forward(self, model_input):
assert not self._is_graph_capturing
state = self
if model_input.attn_metadata.use_cuda_graph:
batch_size = model_input.input_tokens.shape[0]
state = (self.runner.graph_runners[model_input.virtual_engine]
[batch_size].attn_state)
model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
)
model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
model_input.attn_metadata.begin_forward()
@dataclass
class FlashInferMetadata(AttentionMetadata):
# Maximum sequence length among prefill batch. 0 if there are decoding
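
FlashInferState above centralizes what previously lived in the model runners: lazily created wrappers that share one workspace buffer, the tensor-core heuristic based on the query-to-KV head ratio, and a begin_forward that, under CUDA graphs, borrows the wrappers saved with the captured graph. A reduced sketch of the decode-wrapper setup, with FlashInfer treated as optional exactly as in the try/except import at the top of this file; build_decode_wrapper is a hypothetical helper:

import torch

try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
except ImportError:  # FlashInfer is an optional dependency.
    BatchDecodeWithPagedKVCacheWrapper = None

FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024


def build_decode_wrapper(num_qo_heads: int, num_kv_heads: int,
                         device: torch.device):
    # A single uint8 workspace buffer is shared by all wrappers of a runner.
    workspace = torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
                            dtype=torch.uint8,
                            device=device)
    # FlashInfer's decode path can use tensor cores once grouped-query
    # attention packs at least four query heads per KV head.
    use_tensor_cores = num_qo_heads // num_kv_heads >= 4
    return BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD",
                                              use_tensor_cores=use_tensor_cores)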

View File

@ -8,6 +8,7 @@ import torch
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
@ -28,6 +29,10 @@ class IpexAttnBackend(AttentionBackend):
def get_metadata_cls() -> Type["IpexAttnMetadata"]:
return IpexAttnMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,

View File

@ -1,11 +1,12 @@
from dataclasses import dataclass
from typing import List, Tuple
from typing import List, Tuple, Type
import openvino as ov
import torch
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.backends.utils import CommonAttentionState
class OpenVINOAttentionBackend(AttentionBackend):
@ -24,6 +25,10 @@ class OpenVINOAttentionBackend(AttentionBackend):
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
raise NotImplementedError
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
return OpenVINOAttentionMetadata(*args, **kwargs)

View File

@ -6,6 +6,7 @@ import torch_xla.experimental.custom_kernel # Required to register custom ops.
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
class PallasAttentionBackend(AttentionBackend):
@ -18,6 +19,10 @@ class PallasAttentionBackend(AttentionBackend):
def get_metadata_cls() -> Type["PallasMetadata"]:
return PallasMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,

View File

@ -7,7 +7,8 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
@ -33,6 +34,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
return ROCmFlashAttentionMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,

View File

@ -8,6 +8,7 @@ from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import is_cpu
@ -34,6 +35,10 @@ class TorchSDPABackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]:
return TorchSDPAMetadata
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,

View File

@ -1,12 +1,17 @@
"""Attention backend utils"""
from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union
import numpy as np
import torch
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState)
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerBase
# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
@ -269,3 +274,69 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
class CommonAttentionState(AttentionState):
def __init__(self, runner: "ModelRunnerBase"):
self.runner = runner
self._is_graph_capturing = False
@contextmanager
def graph_capture(self, max_batch_size: int):
self._is_graph_capturing = True
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.long,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
def graph_clone(self, batch_size: int) -> "CommonAttentionState":
assert self._is_graph_capturing
return self.__class__(self.runner)
def graph_capture_get_metadata_for_batch(self, batch_size: int):
assert self._is_graph_capturing
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size],
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self._graph_block_tables[:batch_size],
use_cuda_graph=True,
)
return attn_metadata
def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]:
return {
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
def prepare_graph_input_buffers(self, input_buffers,
attn_metadata) -> None:
input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
def begin_forward(self, model_input) -> None:
return
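
CommonAttentionState is what most backends now return from get_state_cls(): it captures only slot_mapping, seq_lens_tensor and block_tables and needs no per-step preparation. A backend that needs one additional graph input could subclass it instead of reimplementing AttentionState from scratch. The class below is a hypothetical sketch, not part of this commit; the "extra" buffer and the attn_metadata.extra field are invented for illustration:

from contextlib import contextmanager
from typing import Any, Dict

import torch

from vllm.attention.backends.utils import CommonAttentionState


class ExtraBufferAttentionState(CommonAttentionState):
    """Hypothetical: default capture/replay behaviour plus one extra tensor."""

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        # Allocate the extra buffer for the largest batch alongside the base
        # class's slot_mapping / seq_lens / block_tables buffers.
        self._graph_extra = torch.zeros(max_batch_size,
                                        dtype=torch.int32,
                                        device=self.runner.device)
        with super().graph_capture(max_batch_size):
            yield
        del self._graph_extra

    def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]:
        buffers = super().get_graph_input_buffers(attn_metadata)
        buffers["extra"] = self._graph_extra[:attn_metadata.num_decode_tokens]
        return buffers

    def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
                                    attn_metadata) -> None:
        super().prepare_graph_input_buffers(input_buffers, attn_metadata)
        # Copy this step's values into the captured buffer before replay;
        # attn_metadata.extra is assumed to exist on this backend's metadata.
        input_buffers["extra"].copy_(attn_metadata.extra, non_blocking=True)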

View File

@ -11,7 +11,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.backends.utils import (CommonAttentionState,
CommonMetadataBuilder)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
@ -37,6 +38,10 @@ class XFormersBackend(AttentionBackend):
def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
return XFormersMetadataBuilder
@staticmethod
def get_state_cls() -> Type["CommonAttentionState"]:
return CommonAttentionState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,

View File

@ -11,17 +11,6 @@ except ModuleNotFoundError:
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
@ -90,11 +79,6 @@ class TP1DraftModelRunner(ModelRunner):
observability_config=observability_config,
)
self.flashinfer_decode_workspace_buffer = None
self.flashinfer_decode_wrapper = None
self.flashinfer_prefill_workspace_buffer = None
self.flashinfer_prefill_wrapper = None
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries):
@ -270,36 +254,7 @@ class TP1DraftModelRunner(ModelRunner):
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
if self.attn_backend.get_name() == "flashinfer":
assert model_input.attn_metadata is not None
assert model_input.input_tokens is not None
if self.flashinfer_decode_workspace_buffer is None:
self.flashinfer_decode_workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.device)
self.flashinfer_decode_wrapper = \
BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_decode_workspace_buffer, "NHD")
self.flashinfer_prefill_workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.device)
self.flashinfer_prefill_wrapper = \
BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_prefill_workspace_buffer, "NHD")
model_input.attn_metadata.prefill_wrapper = \
self.flashinfer_prefill_wrapper
if model_input.attn_metadata.use_cuda_graph:
batch_size = model_input.input_tokens.shape[0]
model_input.attn_metadata.decode_wrapper = \
self.graph_runners[model_input.
virtual_engine][batch_size].flashinfer_decode_wrapper
else:
model_input.attn_metadata.decode_wrapper = \
self.flashinfer_decode_wrapper
model_input.attn_metadata.begin_forward()
self.attn_state.begin_forward(model_input)
# Detect exec mode
assert model_input.attn_metadata is not None

View File

@ -6,6 +6,7 @@ import torch.distributed
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
get_global_forced_attn_backend,
global_force_attn_backend)
@ -20,7 +21,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
from vllm.worker.model_runner import (_PAD_SLOT_ID, GPUModelRunnerBase,
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
@ -395,7 +396,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len)
else:
for i in range(0, seq_len):
block_number = seq_group_metadata.cross_block_table[

View File

@ -13,19 +13,10 @@ import torch
import torch.distributed
import torch.nn as nn
try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeWithPagedKVCacheWrapper = None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
BatchPrefillWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
@ -52,8 +43,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
flatten_2d_lists, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available)
flatten_2d_lists, is_hip, is_pin_memory_available)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
@ -66,7 +56,6 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
@ -858,6 +847,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.kv_cache_dtype,
self.block_size,
) if num_attn_heads else None
if self.attn_backend:
self.attn_state = self.attn_backend.get_state_cls()(
weakref.proxy(self))
else:
self.attn_state = CommonAttentionState(weakref.proxy(self))
# Multi-modal data support
self.input_registry = input_registry
@ -872,11 +866,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
self.flashinfer_decode_workspace_buffer = None
self.flashinfer_decode_wrapper = None
self.flashinfer_prefill_workspace_buffer = None
self.flashinfer_prefill_wrapper = None
set_cpu_offload_max_bytes(
int(self.cache_config.cpu_offload_gb * 1024**3))
@ -1203,10 +1192,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
slot_mapping.fill_(_PAD_SLOT_ID)
seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda()
intermediate_inputs = None
if not get_pp_group().is_first_rank:
intermediate_inputs = self.model.make_empty_intermediate_tensors(
@ -1226,102 +1211,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
]
if self.attn_backend.get_name() == "flashinfer":
# For flashinfer, different batch sizes will share the
# same workspace buffer.
decode_workspace_buffer = \
torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.device)
indices_buffer = torch.empty(max_batch_size *
self.cache_config.num_gpu_blocks,
dtype=torch.int32,
device=self.device)
indptr_buffer = torch.empty(max_batch_size + 1,
dtype=torch.int32,
device=self.device)
last_page_len_buffer = torch.empty(max_batch_size,
dtype=torch.int32,
device=self.device)
with graph_capture() as graph_capture_context:
with self.attn_state.graph_capture(
max_batch_size), graph_capture() as graph_capture_context:
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for virtual_engine in range(
self.parallel_config.pipeline_parallel_size):
for batch_size in reversed(batch_size_capture_list):
if self.attn_backend.get_name() == "flashinfer":
_indptr_buffer = indptr_buffer[:batch_size + 1]
_last_page_len_buffer = last_page_len_buffer[:
batch_size]
num_qo_heads = (
self.model_config.get_num_attention_heads(
self.parallel_config))
num_kv_heads = self.model_config.get_num_kv_heads(
self.parallel_config)
if num_qo_heads // num_kv_heads >= 4:
use_tensor_cores = True
else:
use_tensor_cores = False
decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
decode_workspace_buffer, _indptr_buffer,
indices_buffer, _last_page_len_buffer, "NHD",
use_tensor_cores)
kv_cache_dtype = get_kv_cache_torch_dtype(
self.kv_cache_dtype, self.model_config.dtype)
paged_kv_indptr_tensor_host = torch.arange(
0, batch_size + 1, dtype=torch.int32)
paged_kv_indices_tensor_host = torch.arange(
0, batch_size, dtype=torch.int32)
paged_kv_last_page_len_tensor_host = torch.full(
(batch_size, ), self.block_size, dtype=torch.int32)
query_start_loc_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
slot_mapping=slot_mapping[:batch_size],
num_prefill_tokens=0,
num_decode_tokens=batch_size,
max_prefill_seq_len=0,
block_tables=block_tables,
paged_kv_indptr=paged_kv_indptr_tensor_host,
paged_kv_indices=paged_kv_indices_tensor_host,
paged_kv_last_page_len=
paged_kv_last_page_len_tensor_host,
num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_dim=self.model_config.get_head_size(),
page_size=self.block_size,
seq_start_loc=None,
query_start_loc=query_start_loc_host,
device=self.device,
data_type=kv_cache_dtype,
use_cuda_graph=True,
decode_wrapper=decode_wrapper,
prefill_wrapper=None)
attn_metadata.begin_forward()
else:
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping[:batch_size],
seq_lens=None,
seq_lens_tensor=seq_lens[:batch_size],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)
attn_metadata = (
self.attn_state.graph_capture_get_metadata_for_batch(
batch_size))
if self.lora_config:
lora_mapping = LoRAMapping(
@ -1339,17 +1238,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
set(), prompt_adapter_mapping)
graph_runner = CUDAGraphRunner(
self.model, self.attn_backend.get_name())
if self.attn_backend.get_name() == "flashinfer":
graph_runner.flashinfer_indptr_buffer = _indptr_buffer
graph_runner.flashinfer_indices_buffer = indices_buffer
graph_runner.flashinfer_last_page_len_buffer = \
_last_page_len_buffer
graph_runner.flashinfer_decode_workspace_buffer = \
decode_workspace_buffer
graph_runner.flashinfer_decode_wrapper = \
decode_wrapper
self.model, self.attn_backend.get_name(),
self.attn_state.graph_clone(batch_size))
capture_inputs = {
"input_ids":
@ -1476,36 +1366,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
if self.attn_backend.get_name() == "flashinfer":
assert model_input.attn_metadata is not None
assert model_input.input_tokens is not None
if self.flashinfer_decode_workspace_buffer is None:
self.flashinfer_decode_workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.device)
self.flashinfer_decode_wrapper = \
BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_decode_workspace_buffer, "NHD")
self.flashinfer_prefill_workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.device)
self.flashinfer_prefill_wrapper = \
BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_prefill_workspace_buffer, "NHD")
model_input.attn_metadata.prefill_wrapper = \
self.flashinfer_prefill_wrapper
if model_input.attn_metadata.use_cuda_graph:
batch_size = model_input.input_tokens.shape[0]
model_input.attn_metadata.decode_wrapper = self.graph_runners[
model_input.
virtual_engine][batch_size].flashinfer_decode_wrapper
else:
model_input.attn_metadata.decode_wrapper = \
self.flashinfer_decode_wrapper
model_input.attn_metadata.begin_forward()
self.attn_state.begin_forward(model_input)
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
@ -1613,22 +1474,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
class CUDAGraphRunner:
def __init__(self, model: nn.Module, backend_name: str):
def __init__(self, model: nn.Module, backend_name: str,
attn_state: AttentionState):
self.model = model
self.backend_name = backend_name
self.attn_state = attn_state
self.input_buffers: Dict[str, torch.Tensor] = {}
self.output_buffers: Dict[str, torch.Tensor] = {}
self._graph: Optional[torch.cuda.CUDAGraph] = None
self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
self.flashinfer_decode_wrapper: Optional[
CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None
@property
def graph(self):
assert self._graph is not None
@ -1693,25 +1549,13 @@ class CUDAGraphRunner:
torch.cuda.synchronize()
# Save the input and output buffers.
if self.backend_name == "flashinfer":
self.input_buffers = {
"input_ids": input_ids,
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping,
**kwargs,
}
else:
self.input_buffers = {
"input_ids": input_ids,
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor":
attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
**kwargs,
}
self.input_buffers = {
"input_ids": input_ids,
"positions": positions,
"kv_caches": kv_caches,
**self.attn_state.get_graph_input_buffers(attn_metadata),
**kwargs,
}
if intermediate_inputs is not None:
self.input_buffers.update(intermediate_inputs.tensors)
if get_pp_group().is_last_rank:
@ -1739,12 +1583,8 @@ class CUDAGraphRunner:
self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True)
if self.backend_name != "flashinfer":
self.input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor,
non_blocking=True)
self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
self.attn_state.prepare_graph_input_buffers(self.input_buffers,
attn_metadata)
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
**kwargs)
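
With CUDAGraphRunner now parameterized by an AttentionState, the backend-name branches (the old 'if self.backend_name == "flashinfer"' paths) are gone from both capture and replay, and a new backend only has to answer get_state_cls(). A quick sanity check one could write against the new contract (illustrative; not a test added by this commit):

from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.flash_attn import FlashAttentionBackend
from vllm.attention.backends.utils import CommonAttentionState


def test_backend_exposes_attention_state():
    # Every backend must now answer get_state_cls(); FlashAttention, like most
    # backends in this diff, reuses the shared CommonAttentionState, while
    # FlashInfer supplies its own FlashInferState.
    state_cls = FlashAttentionBackend.get_state_cls()
    assert state_cls is CommonAttentionState
    assert issubclass(state_cls, AttentionState)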