[Core] Add AttentionState abstraction (#7663)

Antoni Baum 2024-08-20 11:50:45 -07:00 committed by GitHub
parent c6af027a35
commit 3b682179dd
16 changed files with 372 additions and 247 deletions

View File

@@ -5,6 +5,7 @@ import torch
 from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.utils import CommonAttentionState
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.worker.embedding_model_runner import (
@@ -29,7 +30,11 @@ class MockAttentionBackend(AttentionBackend):
     @staticmethod
     def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
-        raise AttentionMetadataBuilder
+        return AttentionMetadataBuilder
+
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
 
     @staticmethod
     def get_kv_cache_shape(

View File

@@ -1,7 +1,7 @@
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionState, AttentionType)
 from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
@@ -12,5 +12,6 @@ __all__ = [
     "AttentionType",
     "AttentionMetadataBuilder",
     "Attention",
+    "AttentionState",
     "get_attn_backend",
 ]

View File

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from dataclasses import dataclass, fields
 from enum import Enum, auto
 from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
@@ -7,7 +8,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
 import torch
 
 if TYPE_CHECKING:
-    from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase
+    from vllm.worker.model_runner_base import (ModelRunnerBase,
+                                               ModelRunnerInputBase,
+                                               ModelRunnerInputBuilderBase)
 
 
 class AttentionType(Enum):
@@ -34,6 +37,11 @@ class AttentionBackend(ABC):
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         raise NotImplementedError
 
+    @staticmethod
+    @abstractmethod
+    def get_state_cls() -> Type["AttentionState"]:
+        raise NotImplementedError
+
     @classmethod
     def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
         return cls.get_metadata_cls()(*args, **kwargs)
@@ -126,6 +134,47 @@ class AttentionMetadata:
 T = TypeVar("T", bound=AttentionMetadata)
 
 
+class AttentionState(ABC, Generic[T]):
+    """Holds attention backend-specific objects reused during the
+    lifetime of the model runner."""
+
+    @abstractmethod
+    def __init__(self, runner: "ModelRunnerBase"):
+        ...
+
+    @abstractmethod
+    @contextmanager
+    def graph_capture(self, max_batch_size: int):
+        """Context manager used when capturing CUDA graphs."""
+        yield
+
+    @abstractmethod
+    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
+        """Clone attention state to save in CUDA graph metadata."""
+        ...
+
+    @abstractmethod
+    def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T:
+        """Get attention metadata for CUDA graph capture of batch_size."""
+        ...
+
+    @abstractmethod
+    def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]:
+        """Get attention-specific input buffers for CUDA graph capture."""
+        ...
+
+    @abstractmethod
+    def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
+                                    attn_metadata: T) -> None:
+        """In-place modify input buffers dict for CUDA graph replay."""
+        ...
+
+    @abstractmethod
+    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
+        """Prepare state for forward pass."""
+        ...
+
+
 class AttentionMetadataBuilder(ABC, Generic[T]):
     """Abstract class for attention metadata builders."""

View File

@@ -5,7 +5,8 @@ import torch
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
-from vllm.attention.backends.utils import CommonMetadataBuilder
+from vllm.attention.backends.utils import (CommonAttentionState,
+                                           CommonMetadataBuilder)
 from vllm.attention.ops.blocksparse_attention.interface import (
     LocalStridedBlockSparseAttn, get_head_sliding_step)
 from vllm.attention.ops.paged_attn import PagedAttention
@@ -98,6 +99,10 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
     def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
         return BlocksparseFlashAttentionMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

View File

@@ -9,7 +9,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
                                               AttentionType)
-from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
+from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
+                                           compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
@@ -142,6 +143,10 @@ class FlashAttentionBackend(AttentionBackend):
     def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
         return FlashAttentionMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
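Most backends in this PR reuse the shared CommonAttentionState; only FlashInfer ships its own subclass. On the consumer side the hook is used roughly like this (hedged sketch; the real wiring appears in the model_runner.py hunk near the end of this diff, and the helper name here is invented):

import weakref

def make_attn_state(backend, runner):
    # The runner passes a weak proxy of itself so the state object does not
    # create a reference cycle with the runner that owns it.
    return backend.get_state_cls()(weakref.proxy(runner))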

View File

@@ -1,14 +1,19 @@
+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
 
 try:
     from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
     from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
     import vllm.attention.backends.flash_attn  # noqa
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
 except ImportError:
     BatchDecodeWithPagedKVCacheWrapper = None
+    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
     BatchPrefillWithPagedKVCacheWrapper = None
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
 
 import torch
@@ -16,7 +21,7 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
-                                              AttentionType)
+                                              AttentionState, AttentionType)
 from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
@@ -46,6 +51,10 @@ class FlashInferBackend(AttentionBackend):
     def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
         return FlashInferMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["FlashInferState"]:
+        return FlashInferState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
@@ -75,6 +84,160 @@ class FlashInferBackend(AttentionBackend):
         return [64, 128, 256]
 
 
+class FlashInferState(AttentionState):
+
+    def __init__(self, runner):
+        self.runner = runner
+        self._is_graph_capturing = False
+        self._workspace_buffer = None
+        self._decode_wrapper = None
+        self._prefill_wrapper = None
+
+    def _get_workspace_buffer(self):
+        if self._workspace_buffer is None:
+            self._workspace_buffer = torch.empty(
+                FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=self.runner.device)
+        return self._workspace_buffer
+
+    def _get_prefill_wrapper(self):
+        if self._prefill_wrapper is None:
+            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                self._get_workspace_buffer(), "NHD")
+        return self._prefill_wrapper
+
+    def _get_decode_wrapper(self):
+        if self._decode_wrapper is None:
+            num_qo_heads = (self.runner.model_config.get_num_attention_heads(
+                self.runner.parallel_config))
+            num_kv_heads = self.runner.model_config.get_num_kv_heads(
+                self.runner.parallel_config)
+            use_tensor_cores = num_qo_heads // num_kv_heads >= 4
+            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self._get_workspace_buffer(),
+                "NHD",
+                use_tensor_cores=use_tensor_cores)
+        return self._decode_wrapper
+
+    @contextmanager
+    def graph_capture(self, max_batch_size: int):
+        self._is_graph_capturing = True
+        self._graph_decode_wrapper = None
+        self._graph_slot_mapping = torch.full((max_batch_size, ),
+                                              PAD_SLOT_ID,
+                                              dtype=torch.long,
+                                              device=self.runner.device)
+        self._graph_seq_lens = torch.ones(max_batch_size,
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+        self._graph_block_tables = torch.from_numpy(
+            self.runner.graph_block_tables).to(device=self.runner.device)
+        self._graph_decode_workspace_buffer = self._get_workspace_buffer()
+        self._graph_indices_buffer = torch.empty(
+            max_batch_size * self.runner.cache_config.num_gpu_blocks,
+            dtype=torch.int32,
+            device=self.runner.device)
+        self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
+                                                dtype=torch.int32,
+                                                device=self.runner.device)
+        self._graph_last_page_len_buffer = torch.empty(
+            max_batch_size, dtype=torch.int32, device=self.runner.device)
+        yield
+        self._is_graph_capturing = False
+        del self._graph_slot_mapping
+        del self._graph_seq_lens
+        del self._graph_block_tables
+        del self._graph_decode_workspace_buffer
+        del self._graph_indices_buffer
+        del self._graph_indptr_buffer
+        del self._graph_last_page_len_buffer
+        del self._graph_decode_wrapper
+
+    def graph_clone(self, batch_size: int):
+        assert self._is_graph_capturing
+        state = self.__class__(self.runner)
+        state._workspace_buffer = self._graph_decode_workspace_buffer
+        state._decode_wrapper = self._graph_decode_wrapper
+        state._prefill_wrapper = self._get_prefill_wrapper()
+        return state
+
+    def graph_capture_get_metadata_for_batch(self, batch_size: int):
+        assert self._is_graph_capturing
+        _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
+        _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]
+
+        num_qo_heads = (self.runner.model_config.get_num_attention_heads(
+            self.runner.parallel_config))
+        num_kv_heads = self.runner.model_config.get_num_kv_heads(
+            self.runner.parallel_config)
+        use_tensor_cores = num_qo_heads // num_kv_heads >= 4
+        self._graph_decode_wrapper = \
+            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
+            self._graph_decode_workspace_buffer, _indptr_buffer,
+            self._graph_indices_buffer, _last_page_len_buffer, "NHD",
+            use_tensor_cores)
+        kv_cache_dtype = get_kv_cache_torch_dtype(
+            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
+
+        paged_kv_indptr_tensor_host = torch.arange(0,
+                                                   batch_size + 1,
+                                                   dtype=torch.int32)
+        paged_kv_indices_tensor_host = torch.arange(0,
+                                                    batch_size,
+                                                    dtype=torch.int32)
+        paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
+                                                        self.runner.block_size,
+                                                        dtype=torch.int32)
+        query_start_loc_host = torch.arange(0,
+                                            batch_size + 1,
+                                            dtype=torch.int32)
+
+        attn_metadata = self.runner.attn_backend.make_metadata(
+            num_prefills=0,
+            slot_mapping=self._graph_slot_mapping[:batch_size],
+            num_prefill_tokens=0,
+            num_decode_tokens=batch_size,
+            max_prefill_seq_len=0,
+            block_tables=self._graph_block_tables,
+            paged_kv_indptr=paged_kv_indptr_tensor_host,
+            paged_kv_indices=paged_kv_indices_tensor_host,
+            paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
+            num_qo_heads=num_qo_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=self.runner.model_config.get_head_size(),
+            page_size=self.runner.block_size,
+            seq_start_loc=None,
+            query_start_loc=query_start_loc_host,
+            device=self.runner.device,
+            data_type=kv_cache_dtype,
+            use_cuda_graph=True,
+            decode_wrapper=self._graph_decode_wrapper,
+            prefill_wrapper=None)
+        attn_metadata.begin_forward()
+        return attn_metadata
+
+    def get_graph_input_buffers(self, attn_metadata):
+        return {
+            "slot_mapping": attn_metadata.slot_mapping,
+        }
+
+    def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
+        return
+
+    def begin_forward(self, model_input):
+        assert not self._is_graph_capturing
+        state = self
+        if model_input.attn_metadata.use_cuda_graph:
+            batch_size = model_input.input_tokens.shape[0]
+            state = (self.runner.graph_runners[model_input.virtual_engine]
+                     [batch_size].attn_state)
+        model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
+        )
+        model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
+        model_input.attn_metadata.begin_forward()
+
+
 @dataclass
 class FlashInferMetadata(AttentionMetadata):
     # Maximum sequence length among prefill batch. 0 if there are decoding
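Two details of FlashInferState worth calling out. First, the prefill/decode wrappers and the 256 MiB workspace buffer are created lazily and shared, so eager-mode runs pay nothing until the first forward pass. Second, begin_forward re-routes to the graph runner's cloned state when use_cuda_graph is set, so replay uses the wrappers that were captured. A minimal, hedged illustration of the lazy shared-buffer idiom (invented class name, not part of the PR):

import torch

WORKSPACE_BYTES = 256 * 1024 * 1024  # matches FLASHINFER_WORKSPACE_BUFFER_SIZE

class _LazyWorkspace:
    def __init__(self, device: str = "cpu"):
        self._buf = None
        self.device = device

    def get(self) -> torch.Tensor:
        # allocated once on first use, then shared by every wrapper
        if self._buf is None:
            self._buf = torch.empty(WORKSPACE_BYTES,
                                    dtype=torch.uint8,
                                    device=self.device)
        return self._buf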

View File

@@ -8,6 +8,7 @@ import torch
 from vllm._ipex_ops import ipex_ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
+from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
@@ -28,6 +29,10 @@ class IpexAttnBackend(AttentionBackend):
     def get_metadata_cls() -> Type["IpexAttnMetadata"]:
         return IpexAttnMetadata
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

View File

@@ -1,11 +1,12 @@
 from dataclasses import dataclass
-from typing import List, Tuple
+from typing import List, Tuple, Type
 
 import openvino as ov
 import torch
 
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata)
+from vllm.attention.backends.utils import CommonAttentionState
 
 
 class OpenVINOAttentionBackend(AttentionBackend):
@@ -24,6 +25,10 @@ class OpenVINOAttentionBackend(AttentionBackend):
     def make_metadata(*args, **kwargs) -> "AttentionMetadata":
         raise NotImplementedError
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
         return OpenVINOAttentionMetadata(*args, **kwargs)

View File

@@ -6,6 +6,7 @@ import torch_xla.experimental.custom_kernel  # Required to register custom ops.
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
+from vllm.attention.backends.utils import CommonAttentionState
 
 
 class PallasAttentionBackend(AttentionBackend):
@@ -18,6 +19,10 @@ class PallasAttentionBackend(AttentionBackend):
     def get_metadata_cls() -> Type["PallasMetadata"]:
         return PallasMetadata
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

View File

@@ -7,7 +7,8 @@ import torch
 import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
-from vllm.attention.backends.utils import CommonMetadataBuilder
+from vllm.attention.backends.utils import (CommonAttentionState,
+                                           CommonMetadataBuilder)
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
 from vllm.logger import init_logger
@@ -33,6 +34,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
     def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
         return ROCmFlashAttentionMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

View File

@@ -8,6 +8,7 @@ from torch.nn.functional import scaled_dot_product_attention
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
+from vllm.attention.backends.utils import CommonAttentionState
 from vllm.attention.ops.paged_attn import PagedAttentionMetadata
 from vllm.utils import is_cpu
@@ -34,6 +35,10 @@ class TorchSDPABackend(AttentionBackend):
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         return TorchSDPAMetadata
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

View File

@@ -1,12 +1,17 @@
 """Attention backend utils"""
-from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union
 
 import numpy as np
 import torch
 
-from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
+from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
+                            AttentionState)
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
+if TYPE_CHECKING:
+    from vllm.worker.model_runner_base import ModelRunnerBase
+
 # Error string(s) for encoder/decoder
 # unsupported attention scenarios
 STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
@@ -269,3 +274,69 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
         )
+
+
+class CommonAttentionState(AttentionState):
+
+    def __init__(self, runner: "ModelRunnerBase"):
+        self.runner = runner
+        self._is_graph_capturing = False
+
+    @contextmanager
+    def graph_capture(self, max_batch_size: int):
+        self._is_graph_capturing = True
+        self._graph_slot_mapping = torch.full((max_batch_size, ),
+                                              PAD_SLOT_ID,
+                                              dtype=torch.long,
+                                              device=self.runner.device)
+        self._graph_seq_lens = torch.ones(max_batch_size,
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+        self._graph_block_tables = torch.from_numpy(
+            self.runner.graph_block_tables).to(device=self.runner.device)
+        yield
+        self._is_graph_capturing = False
+        del self._graph_slot_mapping
+        del self._graph_seq_lens
+        del self._graph_block_tables
+
+    def graph_clone(self, batch_size: int) -> "CommonAttentionState":
+        assert self._is_graph_capturing
+        return self.__class__(self.runner)
+
+    def graph_capture_get_metadata_for_batch(self, batch_size: int):
+        assert self._is_graph_capturing
+        attn_metadata = self.runner.attn_backend.make_metadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=batch_size,
+            slot_mapping=self._graph_slot_mapping[:batch_size],
+            seq_lens=None,
+            seq_lens_tensor=self._graph_seq_lens[:batch_size],
+            max_query_len=None,
+            max_prefill_seq_len=0,
+            max_decode_seq_len=self.runner.max_seq_len_to_capture,
+            query_start_loc=None,
+            seq_start_loc=None,
+            context_lens_tensor=None,
+            block_tables=self._graph_block_tables[:batch_size],
+            use_cuda_graph=True,
+        )
+        return attn_metadata
+
+    def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]:
+        return {
+            "slot_mapping": attn_metadata.slot_mapping,
+            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
+            "block_tables": attn_metadata.decode_metadata.block_tables,
+        }
+
+    def prepare_graph_input_buffers(self, input_buffers,
+                                    attn_metadata) -> None:
+        input_buffers["seq_lens_tensor"].copy_(
+            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
+        input_buffers["block_tables"].copy_(
+            attn_metadata.decode_metadata.block_tables, non_blocking=True)
+
+    def begin_forward(self, model_input) -> None:
+        return
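CommonAttentionState is consumed by CUDAGraphRunner in the last file of this diff: capture merges get_graph_input_buffers() into the runner's input_buffers, and replay calls prepare_graph_input_buffers() to refresh those buffers in place. A hedged sketch of that pairing (helper names invented for this note):

def capture_input_buffers(attn_state, attn_metadata, input_ids, positions,
                          kv_caches, **kwargs):
    # roughly what CUDAGraphRunner.capture now stores
    return {
        "input_ids": input_ids,
        "positions": positions,
        "kv_caches": kv_caches,
        **attn_state.get_graph_input_buffers(attn_metadata),
        **kwargs,
    }

def refresh_for_replay(attn_state, input_buffers, attn_metadata):
    # roughly what CUDAGraphRunner.forward now does before replaying the graph
    input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                        non_blocking=True)
    attn_state.prepare_graph_input_buffers(input_buffers, attn_metadata)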

View File

@@ -11,7 +11,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
-from vllm.attention.backends.utils import CommonMetadataBuilder
+from vllm.attention.backends.utils import (CommonAttentionState,
+                                           CommonMetadataBuilder)
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
 from vllm.logger import init_logger
@@ -37,6 +38,10 @@ class XFormersBackend(AttentionBackend):
     def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
         return XFormersMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

View File

@@ -11,17 +11,6 @@ except ModuleNotFoundError:
     from vllm.attention.backends.rocm_flash_attn import (
         ROCmFlashAttentionMetadata as FlashAttentionMetadata)
 
-try:
-    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
-    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
-except ImportError:
-    BatchDecodeWithPagedKVCacheWrapper = None
-    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
-    BatchPrefillWithPagedKVCacheWrapper = None
-    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
-
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
@@ -90,11 +79,6 @@ class TP1DraftModelRunner(ModelRunner):
             observability_config=observability_config,
         )
 
-        self.flashinfer_decode_workspace_buffer = None
-        self.flashinfer_decode_wrapper = None
-        self.flashinfer_prefill_workspace_buffer = None
-        self.flashinfer_prefill_wrapper = None
-
     def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                   num_queries):
@@ -270,36 +254,7 @@ class TP1DraftModelRunner(ModelRunner):
                 model_input.prompt_adapter_requests,
                 model_input.prompt_adapter_mapping)
 
-        if self.attn_backend.get_name() == "flashinfer":
-            assert model_input.attn_metadata is not None
-            assert model_input.input_tokens is not None
-            if self.flashinfer_decode_workspace_buffer is None:
-                self.flashinfer_decode_workspace_buffer = torch.empty(
-                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                    dtype=torch.uint8,
-                    device=self.device)
-                self.flashinfer_decode_wrapper = \
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_decode_workspace_buffer, "NHD")
-                self.flashinfer_prefill_workspace_buffer = torch.empty(
-                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                    dtype=torch.uint8,
-                    device=self.device)
-                self.flashinfer_prefill_wrapper = \
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.flashinfer_prefill_workspace_buffer, "NHD")
-
-            model_input.attn_metadata.prefill_wrapper = \
-                self.flashinfer_prefill_wrapper
-            if model_input.attn_metadata.use_cuda_graph:
-                batch_size = model_input.input_tokens.shape[0]
-                model_input.attn_metadata.decode_wrapper = \
-                    self.graph_runners[model_input.
-                    virtual_engine][batch_size].flashinfer_decode_wrapper
-            else:
-                model_input.attn_metadata.decode_wrapper = \
-                    self.flashinfer_decode_wrapper
-            model_input.attn_metadata.begin_forward()
+        self.attn_state.begin_forward(model_input)
 
         # Detect exec mode
         assert model_input.attn_metadata is not None
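The unconditional self.attn_state.begin_forward(model_input) call is safe for every backend: CommonAttentionState.begin_forward is a no-op, and only FlashInferState overrides it to attach its wrappers. Hedged sketch of the dispatch (class names invented for this note):

class _CommonLike:
    def begin_forward(self, model_input) -> None:
        return  # nothing to prepare for paged-attention style backends

class _FlashInferLike(_CommonLike):
    def begin_forward(self, model_input) -> None:
        # would attach prefill/decode wrappers, then call
        # model_input.attn_metadata.begin_forward()
        pass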

View File

@@ -6,6 +6,7 @@ import torch.distributed
 
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata)
+from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
                                      get_global_forced_attn_backend,
                                      global_force_attn_backend)
@@ -20,7 +21,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
-from vllm.worker.model_runner import (_PAD_SLOT_ID, GPUModelRunnerBase,
+from vllm.worker.model_runner import (GPUModelRunnerBase,
                                       ModelInputForGPUBuilder,
                                       ModelInputForGPUWithSamplingMetadata)
 from vllm.worker.model_runner_base import (
@@ -395,7 +396,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
                     # initialized yet. In this case, we just use a dummy
                     # slot mapping.
                     # In embeddings, the block tables are {seq_id: None}.
-                    cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
+                    cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len)
                 else:
                     for i in range(0, seq_len):
                         block_number = seq_group_metadata.cross_block_table[

View File

@@ -13,19 +13,10 @@ import torch
 import torch.distributed
 import torch.nn as nn
 
-try:
-    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
-    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
-except ImportError:
-    BatchDecodeWithPagedKVCacheWrapper = None
-    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
-    BatchPrefillWithPagedKVCacheWrapper = None
-    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
-
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
+from vllm.attention.backends.abstract import AttentionState
+from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ObservabilityConfig, ParallelConfig,
                          PromptAdapterConfig, SchedulerConfig)
@@ -52,8 +43,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
 from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
-                        flatten_2d_lists, get_kv_cache_torch_dtype, is_hip,
-                        is_pin_memory_available)
+                        flatten_2d_lists, is_hip, is_pin_memory_available)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
@@ -66,7 +56,6 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)
 
-_PAD_SLOT_ID = -1
 LORA_WARMUP_RANK = 8
 _BATCH_SIZE_ALIGNMENT = 8
 # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
@@ -858,6 +847,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             self.kv_cache_dtype,
             self.block_size,
         ) if num_attn_heads else None
+        if self.attn_backend:
+            self.attn_state = self.attn_backend.get_state_cls()(
+                weakref.proxy(self))
+        else:
+            self.attn_state = CommonAttentionState(weakref.proxy(self))
 
         # Multi-modal data support
         self.input_registry = input_registry
@@ -872,11 +866,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
         self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
 
-        self.flashinfer_decode_workspace_buffer = None
-        self.flashinfer_decode_wrapper = None
-        self.flashinfer_prefill_workspace_buffer = None
-        self.flashinfer_prefill_wrapper = None
-
         set_cpu_offload_max_bytes(
             int(self.cache_config.cpu_offload_gb * 1024**3))
@@ -1203,10 +1192,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
         input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
-        slot_mapping.fill_(_PAD_SLOT_ID)
-        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
-        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
 
         intermediate_inputs = None
         if not get_pp_group().is_first_rank:
             intermediate_inputs = self.model.make_empty_intermediate_tensors(
@@ -1226,102 +1211,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
 
-        if self.attn_backend.get_name() == "flashinfer":
-            # For flashinfer, different batch sizes will share the
-            # same workspace buffer.
-            decode_workspace_buffer = \
-            torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                        dtype=torch.uint8,
-                        device=self.device)
-            indices_buffer = torch.empty(max_batch_size *
-                                         self.cache_config.num_gpu_blocks,
-                                         dtype=torch.int32,
-                                         device=self.device)
-            indptr_buffer = torch.empty(max_batch_size + 1,
-                                        dtype=torch.int32,
-                                        device=self.device)
-            last_page_len_buffer = torch.empty(max_batch_size,
-                                               dtype=torch.int32,
-                                               device=self.device)
-
-        with graph_capture() as graph_capture_context:
+        with self.attn_state.graph_capture(
+                max_batch_size), graph_capture() as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for virtual_engine in range(
                     self.parallel_config.pipeline_parallel_size):
                 for batch_size in reversed(batch_size_capture_list):
-                    if self.attn_backend.get_name() == "flashinfer":
-                        _indptr_buffer = indptr_buffer[:batch_size + 1]
-                        _last_page_len_buffer = last_page_len_buffer[:
-                                                                     batch_size]
-
-                        num_qo_heads = (
-                            self.model_config.get_num_attention_heads(
-                                self.parallel_config))
-                        num_kv_heads = self.model_config.get_num_kv_heads(
-                            self.parallel_config)
-                        if num_qo_heads // num_kv_heads >= 4:
-                            use_tensor_cores = True
-                        else:
-                            use_tensor_cores = False
-                        decode_wrapper = \
-                            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
-                            decode_workspace_buffer, _indptr_buffer,
-                            indices_buffer, _last_page_len_buffer, "NHD",
-                            use_tensor_cores)
-                        kv_cache_dtype = get_kv_cache_torch_dtype(
-                            self.kv_cache_dtype, self.model_config.dtype)
-
-                        paged_kv_indptr_tensor_host = torch.arange(
-                            0, batch_size + 1, dtype=torch.int32)
-                        paged_kv_indices_tensor_host = torch.arange(
-                            0, batch_size, dtype=torch.int32)
-                        paged_kv_last_page_len_tensor_host = torch.full(
-                            (batch_size, ), self.block_size, dtype=torch.int32)
-                        query_start_loc_host = torch.arange(0,
-                                                            batch_size + 1,
-                                                            dtype=torch.int32)
-
-                        attn_metadata = self.attn_backend.make_metadata(
-                            num_prefills=0,
-                            slot_mapping=slot_mapping[:batch_size],
-                            num_prefill_tokens=0,
-                            num_decode_tokens=batch_size,
-                            max_prefill_seq_len=0,
-                            block_tables=block_tables,
-                            paged_kv_indptr=paged_kv_indptr_tensor_host,
-                            paged_kv_indices=paged_kv_indices_tensor_host,
-                            paged_kv_last_page_len=
-                            paged_kv_last_page_len_tensor_host,
-                            num_qo_heads=num_qo_heads,
-                            num_kv_heads=num_kv_heads,
-                            head_dim=self.model_config.get_head_size(),
-                            page_size=self.block_size,
-                            seq_start_loc=None,
-                            query_start_loc=query_start_loc_host,
-                            device=self.device,
-                            data_type=kv_cache_dtype,
-                            use_cuda_graph=True,
-                            decode_wrapper=decode_wrapper,
-                            prefill_wrapper=None)
-                        attn_metadata.begin_forward()
-                    else:
-                        attn_metadata = self.attn_backend.make_metadata(
-                            num_prefills=0,
-                            num_prefill_tokens=0,
-                            num_decode_tokens=batch_size,
-                            slot_mapping=slot_mapping[:batch_size],
-                            seq_lens=None,
-                            seq_lens_tensor=seq_lens[:batch_size],
-                            max_query_len=None,
-                            max_prefill_seq_len=0,
-                            max_decode_seq_len=self.max_seq_len_to_capture,
-                            query_start_loc=None,
-                            seq_start_loc=None,
-                            context_lens_tensor=None,
-                            block_tables=block_tables[:batch_size],
-                            use_cuda_graph=True,
-                        )
+                    attn_metadata = (
+                        self.attn_state.graph_capture_get_metadata_for_batch(
+                            batch_size))
 
                     if self.lora_config:
                         lora_mapping = LoRAMapping(
@@ -1339,17 +1238,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                         set(), prompt_adapter_mapping)
 
                     graph_runner = CUDAGraphRunner(
-                        self.model, self.attn_backend.get_name())
-
-                    if self.attn_backend.get_name() == "flashinfer":
-                        graph_runner.flashinfer_indptr_buffer = _indptr_buffer
-                        graph_runner.flashinfer_indices_buffer = indices_buffer
-                        graph_runner.flashinfer_last_page_len_buffer = \
-                            _last_page_len_buffer
-                        graph_runner.flashinfer_decode_workspace_buffer = \
-                            decode_workspace_buffer
-                        graph_runner.flashinfer_decode_wrapper = \
-                            decode_wrapper
+                        self.model, self.attn_backend.get_name(),
+                        self.attn_state.graph_clone(batch_size))
 
                     capture_inputs = {
                         "input_ids":
@@ -1476,36 +1366,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                 model_input.prompt_adapter_requests,
                 model_input.prompt_adapter_mapping)
 
-        if self.attn_backend.get_name() == "flashinfer":
-            assert model_input.attn_metadata is not None
-            assert model_input.input_tokens is not None
-            if self.flashinfer_decode_workspace_buffer is None:
-                self.flashinfer_decode_workspace_buffer = torch.empty(
-                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                    dtype=torch.uint8,
-                    device=self.device)
-                self.flashinfer_decode_wrapper = \
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_decode_workspace_buffer, "NHD")
-                self.flashinfer_prefill_workspace_buffer = torch.empty(
-                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                    dtype=torch.uint8,
-                    device=self.device)
-                self.flashinfer_prefill_wrapper = \
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.flashinfer_prefill_workspace_buffer, "NHD")
-
-            model_input.attn_metadata.prefill_wrapper = \
-                self.flashinfer_prefill_wrapper
-            if model_input.attn_metadata.use_cuda_graph:
-                batch_size = model_input.input_tokens.shape[0]
-                model_input.attn_metadata.decode_wrapper = self.graph_runners[
-                    model_input.
-                    virtual_engine][batch_size].flashinfer_decode_wrapper
-            else:
-                model_input.attn_metadata.decode_wrapper = \
-                    self.flashinfer_decode_wrapper
-            model_input.attn_metadata.begin_forward()
+        self.attn_state.begin_forward(model_input)
 
         # Currently cuda graph is only supported by the decode phase.
         assert model_input.attn_metadata is not None
@@ -1613,22 +1474,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
 class CUDAGraphRunner:
 
-    def __init__(self, model: nn.Module, backend_name: str):
+    def __init__(self, model: nn.Module, backend_name: str,
+                 attn_state: AttentionState):
         self.model = model
         self.backend_name = backend_name
+        self.attn_state = attn_state
 
         self.input_buffers: Dict[str, torch.Tensor] = {}
         self.output_buffers: Dict[str, torch.Tensor] = {}
 
         self._graph: Optional[torch.cuda.CUDAGraph] = None
-
-        self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_decode_wrapper: Optional[
-            CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None
 
     @property
     def graph(self):
         assert self._graph is not None
@@ -1693,25 +1549,13 @@ class CUDAGraphRunner:
         torch.cuda.synchronize()
 
         # Save the input and output buffers.
-        if self.backend_name == "flashinfer":
-            self.input_buffers = {
-                "input_ids": input_ids,
-                "positions": positions,
-                "kv_caches": kv_caches,
-                "slot_mapping": attn_metadata.slot_mapping,
-                **kwargs,
-            }
-        else:
-            self.input_buffers = {
-                "input_ids": input_ids,
-                "positions": positions,
-                "kv_caches": kv_caches,
-                "slot_mapping": attn_metadata.slot_mapping,
-                "seq_lens_tensor":
-                attn_metadata.decode_metadata.seq_lens_tensor,
-                "block_tables": attn_metadata.decode_metadata.block_tables,
-                **kwargs,
-            }
+        self.input_buffers = {
+            "input_ids": input_ids,
+            "positions": positions,
+            "kv_caches": kv_caches,
+            **self.attn_state.get_graph_input_buffers(attn_metadata),
+            **kwargs,
+        }
         if intermediate_inputs is not None:
             self.input_buffers.update(intermediate_inputs.tensors)
         if get_pp_group().is_last_rank:
@@ -1739,12 +1583,8 @@ class CUDAGraphRunner:
         self.input_buffers["positions"].copy_(positions, non_blocking=True)
         self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
                                                  non_blocking=True)
-        if self.backend_name != "flashinfer":
-            self.input_buffers["seq_lens_tensor"].copy_(
-                attn_metadata.decode_metadata.seq_lens_tensor,
-                non_blocking=True)
-            self.input_buffers["block_tables"].copy_(
-                attn_metadata.decode_metadata.block_tables, non_blocking=True)
+        self.attn_state.prepare_graph_input_buffers(self.input_buffers,
+                                                    attn_metadata)
         if "seqlen_agnostic_capture_inputs" in self.input_buffers:
             self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                       **kwargs)
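A follow-up test along these lines could guard the new contract, assuming a build that includes this change and an environment where the flash_attn backend imports (sketch only, not included in the PR):

from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.flash_attn import FlashAttentionBackend
from vllm.attention.backends.utils import CommonAttentionState

def test_flash_attn_state_cls():
    state_cls = FlashAttentionBackend.get_state_cls()
    assert state_cls is CommonAttentionState
    assert issubclass(state_cls, AttentionState)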