[Core] Add AttentionState abstraction (#7663)

parent c6af027a35
commit 3b682179dd
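This patch replaces the FlashInfer-specific workspace-buffer and wrapper bookkeeping that previously lived in the GPU model runner and CUDAGraphRunner with a per-backend AttentionState object, obtained through a new AttentionBackend.get_state_cls() hook. A minimal sketch of the resulting wiring, assuming a GPUModelRunnerBase-like `runner` object (the helper names below are illustrative, not code from the patch):

import weakref


def wire_up_attention_state(runner):
    # Each backend now names the state class it needs: CommonAttentionState
    # for most backends, FlashInferState for FlashInfer.
    return runner.attn_backend.get_state_cls()(weakref.proxy(runner))


def before_execute(runner, model_input):
    # One call replaces the per-backend wrapper plumbing execute_model() used
    # to carry: a no-op for CommonAttentionState, wrapper installation plus
    # metadata.begin_forward() for FlashInferState.
    runner.attn_state.begin_forward(model_input)

During CUDA graph capture the same object provides the capture context, the per-batch-size attention metadata, and a clone to store alongside each captured graph, as shown in the hunks below.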
@@ -5,6 +5,7 @@ import torch

from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import CommonAttentionState
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.embedding_model_runner import (

@@ -29,7 +30,11 @@ class MockAttentionBackend(AttentionBackend):

    @staticmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        raise AttentionMetadataBuilder
        return AttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(

@@ -1,7 +1,7 @@
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
                                              AttentionType)
                                              AttentionState, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

@@ -12,5 +12,6 @@ __all__ = [
    "AttentionType",
    "AttentionMetadataBuilder",
    "Attention",
    "AttentionState",
    "get_attn_backend",
]
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,

@@ -7,7 +8,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
import torch

if TYPE_CHECKING:
    from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase
    from vllm.worker.model_runner_base import (ModelRunnerBase,
                                               ModelRunnerInputBase,
                                               ModelRunnerInputBuilderBase)


class AttentionType(Enum):

@@ -34,6 +37,11 @@ class AttentionBackend(ABC):
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_state_cls() -> Type["AttentionState"]:
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        return cls.get_metadata_cls()(*args, **kwargs)

@@ -126,6 +134,47 @@ class AttentionMetadata:
T = TypeVar("T", bound=AttentionMetadata)


class AttentionState(ABC, Generic[T]):
    """Holds attention backend-specific objects reused during the
    lifetime of the model runner."""

    @abstractmethod
    def __init__(self, runner: "ModelRunnerBase"):
        ...

    @abstractmethod
    @contextmanager
    def graph_capture(self, max_batch_size: int):
        """Context manager used when capturing CUDA graphs."""
        yield

    @abstractmethod
    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
        """Clone attention state to save in CUDA graph metadata."""
        ...

    @abstractmethod
    def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T:
        """Get attention metadata for CUDA graph capture of batch_size."""
        ...

    @abstractmethod
    def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]:
        """Get attention-specific input buffers for CUDA graph capture."""
        ...

    @abstractmethod
    def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
                                    attn_metadata: T) -> None:
        """In-place modify input buffers dict for CUDA graph replay."""
        ...

    @abstractmethod
    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
        """Prepare state for forward pass."""
        ...


class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders."""
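A reading aid for the interface above (not part of the diff): a backend opts into the abstraction by returning a state class from get_state_cls(). The sketch below assumes only the interfaces shown in this file; MyMetadata and MyBackend are hypothetical names, and the remaining abstract AttentionBackend methods are omitted for brevity.

from typing import Type

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata)
from vllm.attention.backends.utils import CommonAttentionState


class MyMetadata(AttentionMetadata):
    # Hypothetical metadata subclass; real backends add their own fields.
    pass


class MyBackend(AttentionBackend):
    # Hypothetical backend; get_impl_cls, get_kv_cache_shape, etc. omitted.

    @staticmethod
    def get_metadata_cls() -> Type["MyMetadata"]:
        return MyMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        # Backends with no special CUDA-graph bookkeeping simply reuse the
        # shared state object, exactly as most backends in this diff do.
        return CommonAttentionState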
@@ -5,7 +5,8 @@ import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.backends.utils import (CommonAttentionState,
                                           CommonMetadataBuilder)
from vllm.attention.ops.blocksparse_attention.interface import (
    LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention

@@ -98,6 +99,10 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
    def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
        return BlocksparseFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,

@@ -9,7 +9,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
                                              AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
                                           compute_slot_mapping,
                                           compute_slot_mapping_start_idx,
                                           is_block_tables_empty)
from vllm.utils import async_tensor_h2d, make_tensor_with_pad

@@ -142,6 +143,10 @@ class FlashAttentionBackend(AttentionBackend):
    def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
@@ -1,14 +1,19 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type

try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper

    import vllm.attention.backends.flash_attn  # noqa
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
    BatchPrefillWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

import torch

@@ -16,7 +21,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
                                              AttentionType)
                                              AttentionState, AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                           compute_slot_mapping_start_idx,
                                           is_block_tables_empty)

@@ -46,6 +51,10 @@ class FlashInferBackend(AttentionBackend):
    def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
        return FlashInferMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["FlashInferState"]:
        return FlashInferState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
@@ -75,6 +84,160 @@ class FlashInferBackend(AttentionBackend):
        return [64, 128, 256]


class FlashInferState(AttentionState):

    def __init__(self, runner):
        self.runner = runner
        self._is_graph_capturing = False
        self._workspace_buffer = None
        self._decode_wrapper = None
        self._prefill_wrapper = None

    def _get_workspace_buffer(self):
        if self._workspace_buffer is None:
            self._workspace_buffer = torch.empty(
                FLASHINFER_WORKSPACE_BUFFER_SIZE,
                dtype=torch.uint8,
                device=self.runner.device)
        return self._workspace_buffer

    def _get_prefill_wrapper(self):
        if self._prefill_wrapper is None:
            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                self._get_workspace_buffer(), "NHD")
        return self._prefill_wrapper

    def _get_decode_wrapper(self):
        if self._decode_wrapper is None:
            num_qo_heads = (self.runner.model_config.get_num_attention_heads(
                self.runner.parallel_config))
            num_kv_heads = self.runner.model_config.get_num_kv_heads(
                self.runner.parallel_config)
            use_tensor_cores = num_qo_heads // num_kv_heads >= 4
            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                self._get_workspace_buffer(),
                "NHD",
                use_tensor_cores=use_tensor_cores)
        return self._decode_wrapper

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        self._is_graph_capturing = True
        self._graph_decode_wrapper = None
        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID,
                                              dtype=torch.long,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)
        self._graph_decode_workspace_buffer = self._get_workspace_buffer()
        self._graph_indices_buffer = torch.empty(
            max_batch_size * self.runner.cache_config.num_gpu_blocks,
            dtype=torch.int32,
            device=self.runner.device)
        self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
                                                dtype=torch.int32,
                                                device=self.runner.device)
        self._graph_last_page_len_buffer = torch.empty(
            max_batch_size, dtype=torch.int32, device=self.runner.device)
        yield
        self._is_graph_capturing = False
        del self._graph_slot_mapping
        del self._graph_seq_lens
        del self._graph_block_tables
        del self._graph_decode_workspace_buffer
        del self._graph_indices_buffer
        del self._graph_indptr_buffer
        del self._graph_last_page_len_buffer
        del self._graph_decode_wrapper

    def graph_clone(self, batch_size: int):
        assert self._is_graph_capturing
        state = self.__class__(self.runner)
        state._workspace_buffer = self._graph_decode_workspace_buffer
        state._decode_wrapper = self._graph_decode_wrapper
        state._prefill_wrapper = self._get_prefill_wrapper()
        return state

    def graph_capture_get_metadata_for_batch(self, batch_size: int):
        assert self._is_graph_capturing
        _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
        _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]

        num_qo_heads = (self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config))
        num_kv_heads = self.runner.model_config.get_num_kv_heads(
            self.runner.parallel_config)
        use_tensor_cores = num_qo_heads // num_kv_heads >= 4
        self._graph_decode_wrapper = \
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                self._graph_decode_workspace_buffer, _indptr_buffer,
                self._graph_indices_buffer, _last_page_len_buffer, "NHD",
                use_tensor_cores)
        kv_cache_dtype = get_kv_cache_torch_dtype(
            self.runner.kv_cache_dtype, self.runner.model_config.dtype)

        paged_kv_indptr_tensor_host = torch.arange(0,
                                                   batch_size + 1,
                                                   dtype=torch.int32)
        paged_kv_indices_tensor_host = torch.arange(0,
                                                    batch_size,
                                                    dtype=torch.int32)
        paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
                                                        self.runner.block_size,
                                                        dtype=torch.int32)
        query_start_loc_host = torch.arange(0,
                                            batch_size + 1,
                                            dtype=torch.int32)

        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            max_prefill_seq_len=0,
            block_tables=self._graph_block_tables,
            paged_kv_indptr=paged_kv_indptr_tensor_host,
            paged_kv_indices=paged_kv_indices_tensor_host,
            paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
            num_qo_heads=num_qo_heads,
            num_kv_heads=num_kv_heads,
            head_dim=self.runner.model_config.get_head_size(),
            page_size=self.runner.block_size,
            seq_start_loc=None,
            query_start_loc=query_start_loc_host,
            device=self.runner.device,
            data_type=kv_cache_dtype,
            use_cuda_graph=True,
            decode_wrapper=self._graph_decode_wrapper,
            prefill_wrapper=None)
        attn_metadata.begin_forward()
        return attn_metadata

    def get_graph_input_buffers(self, attn_metadata):
        return {
            "slot_mapping": attn_metadata.slot_mapping,
        }

    def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
        return

    def begin_forward(self, model_input):
        assert not self._is_graph_capturing
        state = self
        if model_input.attn_metadata.use_cuda_graph:
            batch_size = model_input.input_tokens.shape[0]
            state = (self.runner.graph_runners[model_input.virtual_engine]
                     [batch_size].attn_state)
        model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
        )
        model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
        model_input.attn_metadata.begin_forward()


@dataclass
class FlashInferMetadata(AttentionMetadata):
    # Maximum sequence length among prefill batch. 0 if there are decoding
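An aside on the use_tensor_cores expression above (not part of the diff): both wrapper constructors enable FlashInfer's tensor-core decode path when the ratio of query heads to KV heads, i.e. the GQA group size, is at least 4. A standalone sketch of that check, with hypothetical head counts:

def should_use_tensor_cores(num_qo_heads: int, num_kv_heads: int) -> bool:
    # Mirrors the heuristic used by FlashInferState: GQA group size >= 4.
    return num_qo_heads // num_kv_heads >= 4


if __name__ == "__main__":
    print(should_use_tensor_cores(32, 8))   # True: GQA group size 4
    print(should_use_tensor_cores(32, 32))  # False: MHA, group size 1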
@@ -8,6 +8,7 @@ import torch
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)

@@ -28,6 +29,10 @@ class IpexAttnBackend(AttentionBackend):
    def get_metadata_cls() -> Type["IpexAttnMetadata"]:
        return IpexAttnMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,

@@ -1,11 +1,12 @@
from dataclasses import dataclass
from typing import List, Tuple
from typing import List, Tuple, Type

import openvino as ov
import torch

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata)
from vllm.attention.backends.utils import CommonAttentionState


class OpenVINOAttentionBackend(AttentionBackend):

@@ -24,6 +25,10 @@ class OpenVINOAttentionBackend(AttentionBackend):
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
        raise NotImplementedError

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
        return OpenVINOAttentionMetadata(*args, **kwargs)

@@ -6,6 +6,7 @@ import torch_xla.experimental.custom_kernel  # Required to register custom ops.

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState


class PallasAttentionBackend(AttentionBackend):

@@ -18,6 +19,10 @@ class PallasAttentionBackend(AttentionBackend):
    def get_metadata_cls() -> Type["PallasMetadata"]:
        return PallasMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,

@@ -7,7 +7,8 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.backends.utils import (CommonAttentionState,
                                           CommonMetadataBuilder)
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger

@@ -33,6 +34,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
    def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
        return ROCmFlashAttentionMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,

@@ -8,6 +8,7 @@ from torch.nn.functional import scaled_dot_product_attention

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import is_cpu

@@ -34,6 +35,10 @@ class TorchSDPABackend(AttentionBackend):
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return TorchSDPAMetadata

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
@@ -1,12 +1,17 @@
"""Attention backend utils"""
from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union

import numpy as np
import torch

from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
                            AttentionState)
from vllm.utils import async_tensor_h2d, make_tensor_with_pad

if TYPE_CHECKING:
    from vllm.worker.model_runner_base import ModelRunnerBase

# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "

@@ -269,3 +274,69 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )


class CommonAttentionState(AttentionState):

    def __init__(self, runner: "ModelRunnerBase"):
        self.runner = runner
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        self._is_graph_capturing = True
        self._graph_slot_mapping = torch.full((max_batch_size, ),
                                              PAD_SLOT_ID,
                                              dtype=torch.long,
                                              device=self.runner.device)
        self._graph_seq_lens = torch.ones(max_batch_size,
                                          dtype=torch.int32,
                                          device=self.runner.device)
        self._graph_block_tables = torch.from_numpy(
            self.runner.graph_block_tables).to(device=self.runner.device)
        yield
        self._is_graph_capturing = False
        del self._graph_slot_mapping
        del self._graph_seq_lens
        del self._graph_block_tables

    def graph_clone(self, batch_size: int) -> "CommonAttentionState":
        assert self._is_graph_capturing
        return self.__class__(self.runner)

    def graph_capture_get_metadata_for_batch(self, batch_size: int):
        assert self._is_graph_capturing
        attn_metadata = self.runner.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=self._graph_slot_mapping[:batch_size],
            seq_lens=None,
            seq_lens_tensor=self._graph_seq_lens[:batch_size],
            max_query_len=None,
            max_prefill_seq_len=0,
            max_decode_seq_len=self.runner.max_seq_len_to_capture,
            query_start_loc=None,
            seq_start_loc=None,
            context_lens_tensor=None,
            block_tables=self._graph_block_tables[:batch_size],
            use_cuda_graph=True,
        )
        return attn_metadata

    def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]:
        return {
            "slot_mapping": attn_metadata.slot_mapping,
            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
            "block_tables": attn_metadata.decode_metadata.block_tables,
        }

    def prepare_graph_input_buffers(self, input_buffers,
                                    attn_metadata) -> None:
        input_buffers["seq_lens_tensor"].copy_(
            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
        input_buffers["block_tables"].copy_(
            attn_metadata.decode_metadata.block_tables, non_blocking=True)

    def begin_forward(self, model_input) -> None:
        return
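For context (an editorial sketch, not part of the diff): a model runner drives CommonAttentionState during CUDA graph capture roughly as below; `runner` and the batch sizes are placeholders rather than objects taken from this patch.

from vllm.attention.backends.utils import CommonAttentionState


def capture_sketch(runner, batch_sizes=(8, 4, 2, 1), max_batch_size=8):
    attn_state = CommonAttentionState(runner)
    captured = []
    # graph_capture() allocates the shared slot_mapping / seq_lens /
    # block_tables buffers and frees them once capture finishes.
    with attn_state.graph_capture(max_batch_size):
        for batch_size in batch_sizes:  # largest first, as capture_model() does
            attn_metadata = attn_state.graph_capture_get_metadata_for_batch(
                batch_size)
            # Each captured graph keeps its own clone of the state.
            captured.append((attn_state.graph_clone(batch_size), attn_metadata))
    return captured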
@@ -11,7 +11,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.backends.utils import (CommonAttentionState,
                                           CommonMetadataBuilder)
from vllm.attention.ops.paged_attn import (PagedAttention,
                                           PagedAttentionMetadata)
from vllm.logger import init_logger

@@ -37,6 +38,10 @@ class XFormersBackend(AttentionBackend):
    def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
        return XFormersMetadataBuilder

    @staticmethod
    def get_state_cls() -> Type["CommonAttentionState"]:
        return CommonAttentionState

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
@@ -11,17 +11,6 @@ except ModuleNotFoundError:
    from vllm.attention.backends.rocm_flash_attn import (
        ROCmFlashAttentionMetadata as FlashAttentionMetadata)

try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
    BatchPrefillWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig)

@@ -90,11 +79,6 @@ class TP1DraftModelRunner(ModelRunner):
            observability_config=observability_config,
        )

        self.flashinfer_decode_workspace_buffer = None
        self.flashinfer_decode_wrapper = None
        self.flashinfer_prefill_workspace_buffer = None
        self.flashinfer_prefill_wrapper = None

    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                  num_queries):

@@ -270,36 +254,7 @@ class TP1DraftModelRunner(ModelRunner):
                    model_input.prompt_adapter_requests,
                    model_input.prompt_adapter_mapping)

            if self.attn_backend.get_name() == "flashinfer":
                assert model_input.attn_metadata is not None
                assert model_input.input_tokens is not None
                if self.flashinfer_decode_workspace_buffer is None:
                    self.flashinfer_decode_workspace_buffer = torch.empty(
                        FLASHINFER_WORKSPACE_BUFFER_SIZE,
                        dtype=torch.uint8,
                        device=self.device)
                    self.flashinfer_decode_wrapper = \
                        BatchDecodeWithPagedKVCacheWrapper(
                            self.flashinfer_decode_workspace_buffer, "NHD")
                    self.flashinfer_prefill_workspace_buffer = torch.empty(
                        FLASHINFER_WORKSPACE_BUFFER_SIZE,
                        dtype=torch.uint8,
                        device=self.device)
                    self.flashinfer_prefill_wrapper = \
                        BatchPrefillWithPagedKVCacheWrapper(
                            self.flashinfer_prefill_workspace_buffer, "NHD")

                model_input.attn_metadata.prefill_wrapper = \
                    self.flashinfer_prefill_wrapper
                if model_input.attn_metadata.use_cuda_graph:
                    batch_size = model_input.input_tokens.shape[0]
                    model_input.attn_metadata.decode_wrapper = \
                        self.graph_runners[model_input.
                        virtual_engine][batch_size].flashinfer_decode_wrapper
                else:
                    model_input.attn_metadata.decode_wrapper = \
                        self.flashinfer_decode_wrapper
                model_input.attn_metadata.begin_forward()
            self.attn_state.begin_forward(model_input)

            # Detect exec mode
            assert model_input.attn_metadata is not None
@@ -6,6 +6,7 @@ import torch.distributed

from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
                                     get_global_forced_attn_backend,
                                     global_force_attn_backend)

@@ -20,7 +21,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
from vllm.worker.model_runner import (_PAD_SLOT_ID, GPUModelRunnerBase,
from vllm.worker.model_runner import (GPUModelRunnerBase,
                                      ModelInputForGPUBuilder,
                                      ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (

@@ -395,7 +396,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
                # initialized yet. In this case, we just use a dummy
                # slot mapping.
                # In embeddings, the block tables are {seq_id: None}.
                cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len)
            else:
                for i in range(0, seq_len):
                    block_number = seq_group_metadata.cross_block_table[
@@ -13,19 +13,10 @@ import torch
import torch.distributed
import torch.nn as nn

try:
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
    BatchDecodeWithPagedKVCacheWrapper = None
    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
    BatchPrefillWithPagedKVCacheWrapper = None
    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0

import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig)

@@ -52,8 +43,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
                        flatten_2d_lists, get_kv_cache_torch_dtype, is_hip,
                        is_pin_memory_available)
                        flatten_2d_lists, is_hip, is_pin_memory_available)
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
    _add_attn_metadata_broadcastable_dict,

@@ -66,7 +56,6 @@ if TYPE_CHECKING:

logger = init_logger(__name__)

_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8
_BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.

@@ -858,6 +847,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
            self.kv_cache_dtype,
            self.block_size,
        ) if num_attn_heads else None
        if self.attn_backend:
            self.attn_state = self.attn_backend.get_state_cls()(
                weakref.proxy(self))
        else:
            self.attn_state = CommonAttentionState(weakref.proxy(self))

        # Multi-modal data support
        self.input_registry = input_registry

@@ -872,11 +866,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
        self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None

        self.flashinfer_decode_workspace_buffer = None
        self.flashinfer_decode_wrapper = None
        self.flashinfer_prefill_workspace_buffer = None
        self.flashinfer_prefill_wrapper = None

        set_cpu_offload_max_bytes(
            int(self.cache_config.cpu_offload_gb * 1024**3))
@@ -1203,10 +1192,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
        slot_mapping.fill_(_PAD_SLOT_ID)
        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
        intermediate_inputs = None
        if not get_pp_group().is_first_rank:
            intermediate_inputs = self.model.make_empty_intermediate_tensors(

@@ -1226,102 +1211,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]

        if self.attn_backend.get_name() == "flashinfer":
            # For flashinfer, different batch sizes will share the
            # same workspace buffer.
            decode_workspace_buffer = \
                torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
                            dtype=torch.uint8,
                            device=self.device)
            indices_buffer = torch.empty(max_batch_size *
                                         self.cache_config.num_gpu_blocks,
                                         dtype=torch.int32,
                                         device=self.device)
            indptr_buffer = torch.empty(max_batch_size + 1,
                                        dtype=torch.int32,
                                        device=self.device)
            last_page_len_buffer = torch.empty(max_batch_size,
                                               dtype=torch.int32,
                                               device=self.device)

        with graph_capture() as graph_capture_context:
        with self.attn_state.graph_capture(
                max_batch_size), graph_capture() as graph_capture_context:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for virtual_engine in range(
                    self.parallel_config.pipeline_parallel_size):
                for batch_size in reversed(batch_size_capture_list):
                    if self.attn_backend.get_name() == "flashinfer":
                        _indptr_buffer = indptr_buffer[:batch_size + 1]
                        _last_page_len_buffer = last_page_len_buffer[:
                                                                     batch_size]

                        num_qo_heads = (
                            self.model_config.get_num_attention_heads(
                                self.parallel_config))
                        num_kv_heads = self.model_config.get_num_kv_heads(
                            self.parallel_config)
                        if num_qo_heads // num_kv_heads >= 4:
                            use_tensor_cores = True
                        else:
                            use_tensor_cores = False
                        decode_wrapper = \
                            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                                decode_workspace_buffer, _indptr_buffer,
                                indices_buffer, _last_page_len_buffer, "NHD",
                                use_tensor_cores)
                        kv_cache_dtype = get_kv_cache_torch_dtype(
                            self.kv_cache_dtype, self.model_config.dtype)

                        paged_kv_indptr_tensor_host = torch.arange(
                            0, batch_size + 1, dtype=torch.int32)
                        paged_kv_indices_tensor_host = torch.arange(
                            0, batch_size, dtype=torch.int32)
                        paged_kv_last_page_len_tensor_host = torch.full(
                            (batch_size, ), self.block_size, dtype=torch.int32)
                        query_start_loc_host = torch.arange(0,
                                                            batch_size + 1,
                                                            dtype=torch.int32)

                        attn_metadata = self.attn_backend.make_metadata(
                            num_prefills=0,
                            slot_mapping=slot_mapping[:batch_size],
                            num_prefill_tokens=0,
                            num_decode_tokens=batch_size,
                            max_prefill_seq_len=0,
                            block_tables=block_tables,
                            paged_kv_indptr=paged_kv_indptr_tensor_host,
                            paged_kv_indices=paged_kv_indices_tensor_host,
                            paged_kv_last_page_len=
                            paged_kv_last_page_len_tensor_host,
                            num_qo_heads=num_qo_heads,
                            num_kv_heads=num_kv_heads,
                            head_dim=self.model_config.get_head_size(),
                            page_size=self.block_size,
                            seq_start_loc=None,
                            query_start_loc=query_start_loc_host,
                            device=self.device,
                            data_type=kv_cache_dtype,
                            use_cuda_graph=True,
                            decode_wrapper=decode_wrapper,
                            prefill_wrapper=None)
                        attn_metadata.begin_forward()
                    else:
                        attn_metadata = self.attn_backend.make_metadata(
                            num_prefills=0,
                            num_prefill_tokens=0,
                            num_decode_tokens=batch_size,
                            slot_mapping=slot_mapping[:batch_size],
                            seq_lens=None,
                            seq_lens_tensor=seq_lens[:batch_size],
                            max_query_len=None,
                            max_prefill_seq_len=0,
                            max_decode_seq_len=self.max_seq_len_to_capture,
                            query_start_loc=None,
                            seq_start_loc=None,
                            context_lens_tensor=None,
                            block_tables=block_tables[:batch_size],
                            use_cuda_graph=True,
                        )
                    attn_metadata = (
                        self.attn_state.graph_capture_get_metadata_for_batch(
                            batch_size))

                    if self.lora_config:
                        lora_mapping = LoRAMapping(
@@ -1339,17 +1238,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                            set(), prompt_adapter_mapping)

                    graph_runner = CUDAGraphRunner(
                        self.model, self.attn_backend.get_name())

                    if self.attn_backend.get_name() == "flashinfer":
                        graph_runner.flashinfer_indptr_buffer = _indptr_buffer
                        graph_runner.flashinfer_indices_buffer = indices_buffer
                        graph_runner.flashinfer_last_page_len_buffer = \
                            _last_page_len_buffer
                        graph_runner.flashinfer_decode_workspace_buffer = \
                            decode_workspace_buffer
                        graph_runner.flashinfer_decode_wrapper = \
                            decode_wrapper
                        self.model, self.attn_backend.get_name(),
                        self.attn_state.graph_clone(batch_size))

                    capture_inputs = {
                        "input_ids":

@@ -1476,36 +1366,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        if self.attn_backend.get_name() == "flashinfer":
            assert model_input.attn_metadata is not None
            assert model_input.input_tokens is not None
            if self.flashinfer_decode_workspace_buffer is None:
                self.flashinfer_decode_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_decode_wrapper = \
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.flashinfer_decode_workspace_buffer, "NHD")
                self.flashinfer_prefill_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_prefill_wrapper = \
                    BatchPrefillWithPagedKVCacheWrapper(
                        self.flashinfer_prefill_workspace_buffer, "NHD")

            model_input.attn_metadata.prefill_wrapper = \
                self.flashinfer_prefill_wrapper
            if model_input.attn_metadata.use_cuda_graph:
                batch_size = model_input.input_tokens.shape[0]
                model_input.attn_metadata.decode_wrapper = self.graph_runners[
                    model_input.
                    virtual_engine][batch_size].flashinfer_decode_wrapper
            else:
                model_input.attn_metadata.decode_wrapper = \
                    self.flashinfer_decode_wrapper
            model_input.attn_metadata.begin_forward()
        self.attn_state.begin_forward(model_input)

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
@@ -1613,22 +1474,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):

class CUDAGraphRunner:

    def __init__(self, model: nn.Module, backend_name: str):
    def __init__(self, model: nn.Module, backend_name: str,
                 attn_state: AttentionState):
        self.model = model
        self.backend_name = backend_name
        self.attn_state = attn_state

        self.input_buffers: Dict[str, torch.Tensor] = {}
        self.output_buffers: Dict[str, torch.Tensor] = {}

        self._graph: Optional[torch.cuda.CUDAGraph] = None

        self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
        self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
        self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
        self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
        self.flashinfer_decode_wrapper: Optional[
            CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None

    @property
    def graph(self):
        assert self._graph is not None

@@ -1693,25 +1549,13 @@ class CUDAGraphRunner:
        torch.cuda.synchronize()

        # Save the input and output buffers.
        if self.backend_name == "flashinfer":
            self.input_buffers = {
                "input_ids": input_ids,
                "positions": positions,
                "kv_caches": kv_caches,
                "slot_mapping": attn_metadata.slot_mapping,
                **kwargs,
            }
        else:
            self.input_buffers = {
                "input_ids": input_ids,
                "positions": positions,
                "kv_caches": kv_caches,
                "slot_mapping": attn_metadata.slot_mapping,
                "seq_lens_tensor":
                attn_metadata.decode_metadata.seq_lens_tensor,
                "block_tables": attn_metadata.decode_metadata.block_tables,
                **kwargs,
            }
        self.input_buffers = {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            **self.attn_state.get_graph_input_buffers(attn_metadata),
            **kwargs,
        }
        if intermediate_inputs is not None:
            self.input_buffers.update(intermediate_inputs.tensors)
        if get_pp_group().is_last_rank:

@@ -1739,12 +1583,8 @@ class CUDAGraphRunner:
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                 non_blocking=True)
        if self.backend_name != "flashinfer":
            self.input_buffers["seq_lens_tensor"].copy_(
                attn_metadata.decode_metadata.seq_lens_tensor,
                non_blocking=True)
            self.input_buffers["block_tables"].copy_(
                attn_metadata.decode_metadata.block_tables, non_blocking=True)
        self.attn_state.prepare_graph_input_buffers(self.input_buffers,
                                                    attn_metadata)
        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
            self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                      **kwargs)
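In summary (editorial sketch, not part of the diff): the capture and replay bookkeeping that used to branch on backend_name == "flashinfer" now delegates to the AttentionState clone each CUDAGraphRunner receives. A condensed, hypothetical version of that flow:

def capture_buffers(graph_runner, attn_state, attn_metadata,
                    input_ids, positions, kv_caches):
    # Capture: record whichever tensors the backend's state says must stay
    # live for replay, instead of hard-coding per-backend branches.
    graph_runner.input_buffers = {
        "input_ids": input_ids,
        "positions": positions,
        "kv_caches": kv_caches,
        **attn_state.get_graph_input_buffers(attn_metadata),
    }


def prepare_replay(graph_runner, attn_state, attn_metadata):
    # Replay: the state copies fresh per-step values into the captured
    # buffers in place; FlashInferState makes this a no-op beyond the
    # slot_mapping copy the runner already performs.
    attn_state.prepare_graph_input_buffers(graph_runner.input_buffers,
                                           attn_metadata)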