[Core] Refactor _prepare_model_input_tensors - take 2 (#6164)

Cody Yu, committed via GitHub on 2024-07-17 09:37:16 -07:00
parent a9a2e74d21
commit 2fa4623d9e
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
12 changed files with 1050 additions and 470 deletions

View File

@@ -3,7 +3,7 @@ from typing import List, Tuple, Type
import torch
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
@@ -26,6 +26,10 @@ class MockAttentionBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]:
return AttentionMetadata
@staticmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
return AttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,

View File

@@ -1,5 +1,6 @@
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataBuilder)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
@@ -7,6 +8,7 @@ __all__ = [
"Attention",
"AttentionBackend",
"AttentionMetadata",
"AttentionMetadataBuilder",
"Attention",
"get_attn_backend",
]

View File

@@ -1,11 +1,15 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from enum import Enum, auto
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
Tuple, Type, TypeVar)
import torch
if TYPE_CHECKING:
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner_base import ModelRunnerInputBuilderBase
class AttentionType(Enum):
DECODER = auto()  # Decoder attention between previous layer Q/K/V
@@ -35,6 +39,16 @@ class AttentionBackend(ABC):
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)
@staticmethod
@abstractmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError
@classmethod
def make_metadata_builder(cls, *args,
**kwargs) -> "AttentionMetadataBuilder":
return cls.get_builder_cls()(*args, **kwargs)
@staticmethod
@abstractmethod
def get_kv_cache_shape(
@@ -110,6 +124,33 @@ class AttentionMetadata:
T = TypeVar("T", bound=AttentionMetadata)
class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""
@abstractmethod
def __init__(self, input_builder) -> None:
raise NotImplementedError
@abstractmethod
def add_seq_group(self, seq_group_metadata: "SequenceGroupMetadata",
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata and update
corresponding fields (in Python objects).
"""
raise NotImplementedError
@abstractmethod
def build(self, runner: "ModelRunnerInputBuilderBase", seq_lens: List[int],
query_lens: List[int], cuda_graph_pad_size: int,
batch_size: int) -> T:
"""Build attention metadata with on-device tensors."""
raise NotImplementedError
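The new AttentionMetadataBuilder abstraction splits metadata preparation into two phases: add_seq_group() accumulates per-sequence-group state in plain Python objects, and build() materializes on-device tensors once at the end. A minimal, self-contained sketch of that pattern, using toy classes rather than vLLM's real ones:

from typing import List, Type

class ToyMetadata:
    def __init__(self, seq_lens: List[int]) -> None:
        self.seq_lens = seq_lens

class ToyMetadataBuilder:
    def __init__(self) -> None:
        # Phase 1 state: cheap Python lists, appended to per sequence group.
        self.seq_lens: List[int] = []

    def add_seq_group(self, seq_len: int) -> None:
        self.seq_lens.append(seq_len)

    def build(self) -> ToyMetadata:
        # Phase 2: materialize the metadata object (device tensors, in the real code).
        return ToyMetadata(self.seq_lens)

class ToyBackend:
    @staticmethod
    def get_builder_cls() -> Type[ToyMetadataBuilder]:
        return ToyMetadataBuilder

    @classmethod
    def make_metadata_builder(cls, *args, **kwargs) -> ToyMetadataBuilder:
        return cls.get_builder_cls()(*args, **kwargs)

builder = ToyBackend.make_metadata_builder()
for seq_len in (16, 8, 1):
    builder.add_seq_group(seq_len)
print(builder.build().seq_lens)  # [16, 8, 1]

Each real backend in this diff follows the same shape: get_builder_cls() returns its builder, and CommonMetadataBuilder (in backends/utils.py below) supplies a shared implementation that backends such as XFormers, ROCm FlashAttention, and blocksparse specialize via _metadata_cls.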
class AttentionImpl(ABC, Generic[T]):
@abstractmethod

View File

@@ -5,6 +5,7 @@ import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.ops.blocksparse_attention.interface import (
LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
@@ -93,6 +94,10 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]:
return BlocksparseFlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
return BlocksparseFlashAttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@@ -244,6 +249,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
return self._cached_decode_metadata
class BlocksparseFlashAttentionMetadataBuilder(
CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]):
_metadata_cls = BlocksparseFlashAttentionMetadata
class BlocksparseFlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:

View File

@@ -1,13 +1,24 @@
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)
class FlashAttentionBackend(AttentionBackend):
@@ -28,6 +39,10 @@ class FlashAttentionBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@@ -184,6 +199,170 @@ class FlashAttentionMetadata(AttentionMetadata):
return self._cached_decode_metadata
class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = seq_group_metadata.is_prompt
block_tables = seq_group_metadata.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
curr_seq_lens, query_lens, context_lens,
curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size,
seq_group_metadata.block_tables)
def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors."""
device = runner.device
use_captured_graph = cuda_graph_pad_size != -1
logits_soft_cap = getattr(runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
"Please use Flashinfer backend for models with logits_soft_cap"
" (i.e., Gemma-2). Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size + cuda_graph_pad_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
else:
max_block_table_len = max(
len(block_table) for block_table in self.block_tables)
block_tables = make_tensor_with_pad(
self.block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
context_lens_tensor = torch.tensor(self.context_lens,
dtype=torch.int,
device=device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)
return FlashAttentionMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
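The build() method above repeatedly turns per-sequence lengths into start-location tensors with a prefix sum written into a zero-initialized buffer of length N + 1. A small standalone illustration of that cumsum idiom (not vLLM code; CPU tensors for simplicity):

import torch

seq_lens = [4, 2, 3]  # three sequences in a batch
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32)
# Write the cumulative sum into positions [1:], leaving the leading 0 intact.
torch.cumsum(seq_lens_tensor,
             dim=0,
             dtype=seq_start_loc.dtype,
             out=seq_start_loc[1:])
print(seq_start_loc.tolist())  # [0, 4, 6, 9]; sequence i spans [loc[i], loc[i+1])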
class FlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
try:
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
@@ -14,7 +14,18 @@ import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import get_kv_cache_torch_dtype, make_tensor_with_pad
if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)
class FlashInferBackend(AttentionBackend):
@@ -31,6 +42,10 @@ class FlashInferBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashInferMetadata
@staticmethod
def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
return FlashInferMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@@ -188,6 +203,225 @@ class FlashInferMetadata(AttentionMetadata):
return self
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
self.paged_kv_indices: List[int] = []
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first requests page indices in the paged_kv_indices list.
self.paged_kv_indptr: List[int] = [0]
# paged_kv_last_page_len is the length of the last page of each request
self.paged_kv_last_page_len: List[int] = []
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int],
prefix_cache_hit: bool, chunked_prefill_enabled: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = seq_group_metadata.is_prompt
block_tables = seq_group_metadata.block_tables
computed_block_nums = seq_group_metadata.computed_block_nums
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
curr_seq_lens, query_lens, context_lens,
curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
block_table = computed_block_nums
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:]
self.block_tables.append(block_table)
is_profile_run = is_block_tables_empty(block_tables)
# Compute slot mapping.
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size,
seq_group_metadata.block_tables)
# It is not necessary to add paged_kv_indices, paged_kv_indptr,
# and paged_kv_last_page_len for profile run because we will
# create dummy inputs.
if is_profile_run:
return
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
block_table = block_tables[seq_id]
self.paged_kv_indices.extend(block_table[:block_table_bound])
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
block_table_bound)
last_page_len = seq_len % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
self.paged_kv_last_page_len.append(last_page_len)
def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
cuda_graph_pad_size: int, batch_size: int):
device = runner.device
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size + cuda_graph_pad_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
last_paged_kv_indptr = self.paged_kv_indptr[-1]
self.paged_kv_indptr.extend([last_paged_kv_indptr] *
cuda_graph_pad_size)
self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
else:
max_block_table_len = max(
len(block_table) for block_table in self.block_tables)
block_tables = make_tensor_with_pad(
self.block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)
logits_soft_cap = getattr(runner.model_config.hf_config,
"attn_logit_softcapping", None)
if len(self.paged_kv_indptr) > 0:
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device="cpu",
dtype=torch.int)
paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr,
device="cpu",
dtype=torch.int)
paged_kv_last_page_len_tensor = torch.tensor(
self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
else:
paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None
kv_cache_dtype = get_kv_cache_torch_dtype(runner.kv_cache_dtype,
runner.model_config.dtype)
return FlashInferMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
max_prefill_seq_len=max_prefill_seq_len,
block_tables=block_tables,
paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
num_qo_heads=runner.model_config.get_num_attention_heads(
runner.parallel_config),
num_kv_heads=runner.model_config.get_num_kv_heads(
runner.parallel_config),
head_dim=runner.model_config.get_head_size(),
page_size=self.block_size,
seq_start_loc=seq_start_loc,
query_start_loc=query_start_loc,
device=device,
data_type=kv_cache_dtype,
use_cuda_graph=use_captured_graph,
logits_soft_cap=logits_soft_cap)
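The paged-KV bookkeeping above can be reproduced with plain Python. The sketch below (standalone, not vLLM code) builds paged_kv_indices and paged_kv_indptr for the example block tables from the comment in __init__, and mirrors the block-count and last-page-length arithmetic from add_seq_group():

from typing import List

block_tables: List[List[int]] = [[0, 5, 8], [1, 6, 7], [3, 4]]  # one per request

paged_kv_indices: List[int] = []
paged_kv_indptr: List[int] = [0]  # leading 0 marks the start of request 0
for table in block_tables:
    paged_kv_indices.extend(table)
    paged_kv_indptr.append(paged_kv_indptr[-1] + len(table))

print(paged_kv_indices)  # [0, 5, 8, 1, 6, 7, 3, 4]
print(paged_kv_indptr)   # [0, 3, 6, 8]

def block_table_bound(seq_len: int, block_size: int) -> int:
    # Number of valid blocks for seq_len tokens (ceiling division).
    return seq_len // block_size + (1 if seq_len % block_size else 0)

def last_page_len(seq_len: int, block_size: int) -> int:
    # A full final block keeps length block_size instead of 0.
    rem = seq_len % block_size
    return rem if rem != 0 else block_size

print(block_table_bound(16, 16), last_page_len(16, 16))  # 1 16
print(block_table_bound(15, 16), last_page_len(15, 16))  # 1 15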
class FlashInferImpl(AttentionImpl):
def __init__(

View File

@@ -7,6 +7,7 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
@@ -28,6 +29,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]:
return ROCmFlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
return ROCmFlashAttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@@ -166,6 +171,12 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
return self._cached_decode_metadata
class ROCmFlashAttentionMetadataBuilder(
CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
_metadata_cls = ROCmFlashAttentionMetadata
def _make_alibi_bias(alibi_slopes: torch.Tensor,
dtype: torch.dtype,
seq_lens: Optional[List[int]],

View File

@@ -1,7 +1,239 @@
"""Attention backend utils"""
from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
import torch
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
"with encoder/decoder models.")
PAD_SLOT_ID = -1
if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)
def is_block_tables_empty(block_tables: Union[None, Dict]):
"""
Check if block_tables is None or a dictionary with all None values.
"""
if block_tables is None:
return True
if isinstance(block_tables, dict) and all(
value is None for value in block_tables.values()):
return True
return False
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
context_len: int, sliding_window: int,
use_v2_block_manager: bool):
"""
Compute the start index of slot mapping.
"""
start_idx = 0
if is_prompt and sliding_window is not None:
assert use_v2_block_manager or context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention in V1 block manager")
# When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx = max(0, query_len - sliding_window)
return start_idx
def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
seq_id: int, seq_len: int, context_len: int,
start_idx: int, block_size: int,
block_tables: Dict[int, List[int]]):
"""
Compute slot mapping.
"""
if is_profile_run:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping.extend([PAD_SLOT_ID] * seq_len)
return
# Mask the [0, start_idx) tokens of the prompt with
# PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
block_table = block_tables[seq_id]
slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
for i in range(max(start_idx, context_len), seq_len):
block_number = block_table[i // block_size]
block_offset = i % block_size
slot = block_number * block_size + block_offset
slot_mapping.append(slot)
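A concrete run of the two helpers above, reproducing the sliding-window example from the comment (prompt length 10, sliding window 8, block size 4). The block table here is hypothetical; it reuses physical block 0 for tokens 8 and 9, which is what yields the documented slot mapping. This assumes vLLM at this commit is importable:

from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            compute_slot_mapping_start_idx)

slot_mapping: list = []
seq_id, seq_len, context_len, block_size = 0, 10, 0, 4

start_idx = compute_slot_mapping_start_idx(is_prompt=True,
                                           query_len=10,
                                           context_len=context_len,
                                           sliding_window=8,
                                           use_v2_block_manager=True)
assert start_idx == 2  # max(0, 10 - 8): the first two tokens get no slot

compute_slot_mapping(is_profile_run=False,
                     slot_mapping=slot_mapping,
                     seq_id=seq_id,
                     seq_len=seq_len,
                     context_len=context_len,
                     start_idx=start_idx,
                     block_size=block_size,
                     block_tables={seq_id: [0, 1, 0]})  # hypothetical block table

assert slot_mapping[:2] == [PAD_SLOT_ID, PAD_SLOT_ID]
print(slot_mapping)  # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]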
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
_metadata_cls: Type[TAttentionMetadata]
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
self.use_v2_block_manager = (
input_builder.scheduler_config.use_v2_block_manager)
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
token_lens: List[int], seq_lens: List[int],
curr_seq_lens: List[int], query_lens: List[int],
context_lens: List[int],
curr_sliding_window_blocks: List[int], prefix_cache_hit,
chunked_prefill_enabled):
is_prompt = seq_group_metadata.is_prompt
block_tables = seq_group_metadata.block_tables
computed_block_nums = seq_group_metadata.computed_block_nums
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
curr_seq_lens, query_lens, context_lens,
curr_sliding_window_blocks):
self.context_lens.append(context_len)
if is_prompt:
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
block_table = computed_block_nums
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(
is_prompt, query_len, context_len, self.sliding_window,
self.use_v2_block_manager)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size,
seq_group_metadata.block_tables)
def build(self, runner: "GPUModelRunnerBase", seq_lens: List[int],
query_lens: List[int], cuda_graph_pad_size: int,
batch_size: int):
device = runner.device
use_captured_graph = cuda_graph_pad_size != -1
logits_soft_cap = getattr(runner.model_config.hf_config,
"attn_logit_softcapping", None)
if logits_soft_cap is not None:
raise ValueError(
"Please use Flashinfer backend for models with logits_soft_cap "
"(i.e., Gemma-2). Otherwise, the output might be wrong. "
"Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
max_query_len = max(query_lens)
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size + cuda_graph_pad_size
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = runner.graph_block_tables[:batch_size]
for i, block_table in enumerate(self.block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device)
else:
max_block_table_len = max(
len(block_table) for block_table in self.block_tables)
block_tables = make_tensor_with_pad(
self.block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, "query_lens: {}".format(query_lens)
context_lens_tensor = torch.tensor(self.context_lens,
dtype=torch.int,
device=device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
slot_mapping_tensor = torch.tensor(self.slot_mapping,
dtype=torch.long,
device=device)
return self._metadata_cls( # type: ignore
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)

View File

@@ -11,6 +11,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonMetadataBuilder
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
@@ -32,6 +33,10 @@ class XFormersBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]:
return XFormersMetadata
@staticmethod
def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
return XFormersMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
@@ -362,6 +367,11 @@ def _get_seq_len_block_table_args(
raise AttributeError(f"Invalid attention type {str(attn_type)}")
class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
_metadata_cls = XFormersMetadata
class XFormersImpl(AttentionImpl[XFormersMetadata]):
"""
If the input tensors contain prompt tokens, the layout is as follows:

View File

@@ -7,6 +7,7 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_openvino, is_tpu, is_xpu
logger = init_logger(__name__)
@@ -136,7 +137,7 @@ def which_attn_to_use(
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
-if torch.cuda.get_device_capability()[0] != 9:
+if current_platform.get_device_capability()[0] != 9:
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
@@ -145,7 +146,7 @@ def which_attn_to_use(
# FlashAttn in NVIDIA GPUs.
if selected_backend == _Backend.FLASH_ATTN:
-if torch.cuda.get_device_capability()[0] < 8:
+if current_platform.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing "

View File

@@ -2,6 +2,7 @@ import dataclasses
import gc
import time
import warnings
import weakref
from collections import defaultdict
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Type, TypeVar, Union)
@@ -48,9 +49,9 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
-is_pin_memory_available, make_tensor_with_pad)
+is_pin_memory_available)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
@@ -165,6 +166,298 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
return cls(**tensor_dict)
class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
"""TBA"""
def __init__(self,
runner: "GPUModelRunnerBase",
finished_requests_ids: Optional[List[str]] = None):
super().__init__()
self.runner = runner
self.model_input_cls = self.runner._model_input_cls
self.attn_backend = self.runner.attn_backend
self.scheduler_config = self.runner.scheduler_config
self.sliding_window = self.runner.sliding_window
self.block_size = self.runner.block_size
self.enable_lora = self.runner.lora_config is not None
self.enable_prompt_adapter = (self.runner.prompt_adapter_config
is not None)
self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper
self.finished_requests_ids = finished_requests_ids
self.decode_only = True
# Common inputs.
self.input_tokens: List[int] = []
self.input_positions: List[int] = []
self.seq_lens: List[int] = []
self.query_lens: List[int] = []
self.max_decode_seq_len: int = 0
self.request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)
# LoRA inputs.
self.lora_index_mapping: List[int] = []
self.lora_prompt_mapping: List[int] = []
self.lora_requests: Set[LoRARequest] = set()
# Prompt adapter inputs.
self.prompt_adapter_index_mapping: List[int] = []
self.prompt_adapter_prompt_mapping: List[int] = []
self.prompt_adapter_requests: Set[PromptAdapterRequest] = set()
# Multi-modal inputs.
self.multi_modal_inputs_list: List[MultiModalInputs] = []
# Attention metadata inputs.
self.attn_metadata_builder = self.attn_backend.make_metadata_builder(
self)
# Engine/Model configurations.
self.chunked_prefill_enabled = (
self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled)
if self.sliding_window is not None:
self.sliding_window_blocks = (
self.sliding_window + self.block_size - 1) // self.block_size
self.block_aligned_sliding_window = \
self.sliding_window_blocks * self.block_size
def _compute_len_for_sliding_window(self, seq_len: int):
curr_sliding_window_blocks = 0
sliding_seq_len = seq_len
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle slinding window attn.
if self.sliding_window is not None:
curr_sliding_window_blocks = self.sliding_window_blocks
if self.scheduler_config.use_v2_block_manager:
# number of elements in last block
suff_len = seq_len % self.block_size
sliding_seq_len = min(
seq_len, self.block_aligned_sliding_window + suff_len)
if suff_len > 0:
curr_sliding_window_blocks += 1
else:
sliding_seq_len = min(seq_len, self.sliding_window)
return curr_sliding_window_blocks, sliding_seq_len
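A worked instance of the arithmetic in _compute_len_for_sliding_window(), written out standalone (not vLLM code) for a decode sequence with seq_len=10, sliding_window=8, block_size=4 and the v2 block manager enabled:

seq_len, sliding_window, block_size = 10, 8, 4

sliding_window_blocks = (sliding_window + block_size - 1) // block_size  # 2
block_aligned_sliding_window = sliding_window_blocks * block_size        # 8

curr_sliding_window_blocks = sliding_window_blocks
suff_len = seq_len % block_size                       # 2 tokens in the last, partial block
sliding_seq_len = min(seq_len, block_aligned_sliding_window + suff_len)  # min(10, 10) = 10
if suff_len > 0:
    curr_sliding_window_blocks += 1                   # keep one extra block for the partial tail

print(curr_sliding_window_blocks, sliding_seq_len)    # 3 10

Without the v2 block manager, the method instead caps sliding_seq_len at min(seq_len, sliding_window), i.e. 8 in this example.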
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
seq_ids = list(seq_group_metadata.seq_data.keys())
n_seqs = len(seq_ids)
is_prompt = seq_group_metadata.is_prompt
token_chunk_size = seq_group_metadata.token_chunk_size
if is_prompt:
assert n_seqs == 1
self.decode_only = False
# Mapping from request IDs to sequence IDs. Used for Jamba models
# that manages the cache by itself.
self.request_ids_to_seq_ids[seq_group_metadata.request_id] = []
# The number of input tokens in each sequence.
token_lens: List[int] = []
# The number of tokens that are already computed.
context_lens: List[int] = []
# The current sliding window block for each sequence.
curr_sliding_window_blocks: List[int] = []
# The original sequence length (before applying sliding window)
# for each sequence.
orig_seq_lens: List[int] = []
# The sequence length (may be capped to the sliding window).
curr_seq_lens: List[int] = []
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
self.request_ids_to_seq_ids[seq_group_metadata.request_id].append(
seq_id)
computed_block_nums = seq_group_metadata.computed_block_nums
# Check if hit prefix cache (i.e., some blocks are already computed)
# Note that prefix caching does not support sliding window.
prefix_cache_hit = (computed_block_nums is not None
and len(computed_block_nums) > 0
and self.sliding_window is None and is_prompt)
if self.chunked_prefill_enabled and prefix_cache_hit:
raise RuntimeError(
"chunked prefill cannot be used with prefix caching now.")
# Compute context length (the number of tokens that are
# already computed) and sequence length (total number of tokens).
seq_len = seq_data.get_len()
if is_prompt:
context_len = seq_data.get_num_computed_tokens()
else:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len = seq_len - 1
seq_len = min(seq_len, context_len + token_chunk_size)
# Compute tokens.
if is_prompt:
tokens = seq_data.get_token_ids()[context_len:seq_len]
else:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens = [seq_data.get_last_token_id()]
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size
tokens = tokens[context_len:]
# These are seq_len/context_len capped to the sliding window.
# They are passed to decode kernel.
# We still need original seq_len/context_len to compute slot
# mapping (and input position) below.
if is_prompt:
curr_sliding_window_block = 0
sliding_seq_len = seq_len
query_len = seq_len - context_len
else:
curr_sliding_window_block, sliding_seq_len = (
self._compute_len_for_sliding_window(seq_len))
query_len = 1
self.seq_lens.append(sliding_seq_len)
if not is_prompt:
self.max_decode_seq_len = max(self.max_decode_seq_len,
sliding_seq_len)
self.query_lens.append(query_len)
self.input_tokens.extend(tokens)
self.input_positions.extend(list(range(context_len, seq_len)))
# Intermediate data of the current sequence group for
# the attention metadata.
token_lens.append(len(tokens))
context_lens.append(context_len)
curr_seq_lens.append(sliding_seq_len)
curr_sliding_window_blocks.append(curr_sliding_window_block)
orig_seq_lens.append(seq_len)
# Update attention metadata. Note that input builder attributes
# (self.xxx) include all added sequences, so we need to slice
# the last n_seqs sequences.
self.attn_metadata_builder.add_seq_group(
seq_group_metadata, token_lens, orig_seq_lens, curr_seq_lens,
self.query_lens[-n_seqs:], context_lens,
curr_sliding_window_blocks, prefix_cache_hit,
self.chunked_prefill_enabled)
# LoRA data.
if self.enable_lora:
lora_id = seq_group_metadata.lora_int_id
for query_len in self.query_lens[-n_seqs:]:
if lora_id > 0:
self.lora_requests.add(seq_group_metadata.lora_request)
self.lora_index_mapping += [lora_id] * query_len
self.lora_prompt_mapping.extend(
[lora_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
is not None else 1))
# Prompt adapter data. Note that when is_prompt=True,
# we expect only one sequence in the group.
if self.enable_prompt_adapter:
prompt_adapter_id = seq_group_metadata.prompt_adapter_id
if prompt_adapter_id > 0 and is_prompt:
query_len = self.query_lens[-1]
self.prompt_adapter_requests.add(
seq_group_metadata.prompt_adapter_request)
num_tokens = seq_group_metadata.\
prompt_adapter_num_virtual_tokens
pm = [prompt_adapter_id
] * num_tokens + [0] * (query_len - num_tokens)
self.prompt_adapter_index_mapping += pm
self.prompt_adapter_prompt_mapping.extend(
[prompt_adapter_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
else 1))
# Multi-modal data.
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
mm_kwargs = self.multi_modal_input_mapper(mm_data)
self.multi_modal_inputs_list.append(mm_kwargs)
def build(self) -> ModelInputForGPU:
if not self.input_tokens:
return self.model_input_cls()
batch_size = len(self.input_tokens)
use_captured_graph = (
self.decode_only and not self.runner.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and self.max_decode_seq_len <= self.runner.max_seq_len_to_capture)
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
cuda_graph_pad_size = -1
if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
cuda_graph_pad_size = graph_batch_size - batch_size
batch_size = graph_batch_size
# Tokens and positions.
self.input_tokens.extend([0] * cuda_graph_pad_size)
self.input_positions.extend([0] * cuda_graph_pad_size)
input_tokens_tensor = torch.tensor(self.input_tokens,
dtype=torch.long,
device=self.runner.device)
input_positions_tensor = torch.tensor(self.input_positions,
dtype=torch.long,
device=self.runner.device)
# Sequence and query lengths.
self.seq_lens.extend([1] * cuda_graph_pad_size)
# Attention metadata.
attn_metadata = self.attn_metadata_builder.build(
self.runner, self.seq_lens, self.query_lens, cuda_graph_pad_size,
batch_size)
# LoRA data.
if self.enable_lora:
self.lora_index_mapping.extend([0] * cuda_graph_pad_size)
lora_mapping = LoRAMapping(
self.lora_index_mapping,
self.lora_prompt_mapping,
)
else:
lora_mapping = None
# Prompt adapter data.
if self.enable_prompt_adapter:
self.prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
prompt_adapter_mapping = PromptAdapterMapping(
self.prompt_adapter_index_mapping,
self.prompt_adapter_prompt_mapping,
)
else:
prompt_adapter_mapping = None
# Multi-modal data.
multi_modal_kwargs = MultiModalInputs.batch(
self.multi_modal_inputs_list, device=self.runner.device)
return self.model_input_cls(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
attn_metadata=attn_metadata,
seq_lens=self.seq_lens,
query_lens=self.query_lens,
lora_mapping=lora_mapping,
lora_requests=self.lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
request_ids_to_seq_ids=self.request_ids_to_seq_ids,
finished_requests_ids=self.finished_requests_ids,
prompt_adapter_mapping=prompt_adapter_mapping,
prompt_adapter_requests=self.prompt_adapter_requests)
class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"""
Helper class for shared methods between GPU model runners.
@@ -368,464 +661,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
If cuda graph is required, this API automatically pads inputs.
"""
builder = ModelInputForGPUBuilder(weakref.proxy(self),
finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list:
builder.add_seq_group(seq_group_metadata)
return builder.build()  # type: ignore

The previous inline implementation, removed by this commit, was:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
prompt_adapter_index_mapping: List[int] = []
prompt_adapter_prompt_mapping: List[int] = []
prompt_adapter_requests: Set[PromptAdapterRequest] = set()
seq_lens: List[int] = []
prefill_seq_lens: List[int] = []
decode_seq_lens: List[int] = []
context_lens: List[int] = []
query_lens: List[int] = []
block_tables: List[List[int]] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
request_ids_to_seq_ids: Dict[str, List[int]] = defaultdict(list)
decode_only = True
num_prefills = 0
num_prefill_tokens = 0
num_decode_tokens = 0
# The following fields are only for flashinfer
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
paged_kv_indices: List[int] = []
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first requests page indices in the paged_kv_indices list.
paged_kv_indptr: List[int] = [0]
# paged_kv_last_page_len is the length of the last page of each request
paged_kv_last_page_len: List[int] = []
if len(seq_group_metadata_list) == 0:
return self._model_input_cls()
if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window + self.block_size -
1) // self.block_size
block_aligned_sliding_window = \
sliding_window_blocks * self.block_size
for seq_group_metadata in seq_group_metadata_list:
seq_ids = list(seq_group_metadata.seq_data.keys())
is_prompt = seq_group_metadata.is_prompt
for seq_id in seq_ids:
computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")
seq_data = seq_group_metadata.seq_data[seq_id]
if is_prompt:
context_len = seq_data.get_num_computed_tokens()
else:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len = seq_data.get_len() - 1
seq_len = min(
seq_data.get_len(),
context_len + seq_group_metadata.token_chunk_size)
if is_prompt:
tokens = seq_data.get_token_ids()[context_len:seq_len]
else:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens = [seq_data.get_last_token_id()]
# Prefix cache was hit.
# Prefix is not supported with sliding_window
prefix_cache_hit = (computed_block_nums is not None
and len(computed_block_nums) > 0
and self.sliding_window is None
and is_prompt)
# These are seq_len/context_len capped to the sliding window.
# They are passed to decode kernel.
# We still need original seq_len/context_len to compute slot
# mapping (and input position) below.
curr_sliding_window_blocks = None
sliding_seq_len = seq_len
sliding_context_len = context_len
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle slinding window attn.
if (self.sliding_window is not None and not is_prompt):
curr_sliding_window_blocks = sliding_window_blocks
if self.scheduler_config.use_v2_block_manager:
# number of elements in last block
suff_len = seq_len % self.block_size
sliding_seq_len = min(
seq_len, block_aligned_sliding_window + suff_len)
if suff_len > 0:
curr_sliding_window_blocks += 1
else:
sliding_seq_len = min(seq_len, self.sliding_window)
sliding_context_len = sliding_seq_len - 1
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size
tokens = tokens[context_len:]
# need to think what to set it to when we have both sliding
# window and prefix caching...
assert self.sliding_window is None, \
"Prefix caching is not supported with sliding window"
sliding_context_len = context_len
if self.attn_backend.get_name() == "flash-attn":
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
# TODO(woosuk): This is a temporary fix. We should
# provide a unified interface for different backends.
block_table = seq_group_metadata.block_tables[seq_id]
else:
block_table = computed_block_nums
elif (self.scheduler_config.chunked_prefill_enabled
or not is_prompt):
if seq_group_metadata.block_tables is not None:
# chunked prefill or decode
block_table = seq_group_metadata.block_tables[seq_id]
if curr_sliding_window_blocks is not None:
block_table = block_table[
-curr_sliding_window_blocks:]
else:
# Only happens when memory profiling runs.
block_table = []
else:
# Prefill without chunked prefill or memory profiling.
block_table = []
block_tables.append(block_table)
seq_lens.append(sliding_seq_len)
context_lens.append(sliding_context_len)
query_len = sliding_seq_len - sliding_context_len
query_lens.append(query_len)
input_tokens.extend(tokens)
input_positions.extend(list(range(context_len, seq_len)))
lora_id = seq_group_metadata.lora_int_id
prompt_adapter_id = seq_group_metadata.prompt_adapter_id
if is_prompt:
assert len(seq_ids) == 1
num_prefills += 1
num_prefill_tokens += len(tokens)
decode_only = False
prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
num_decode_tokens += query_len
decode_seq_lens.append(sliding_seq_len)
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping += [lora_id] * query_len
lora_prompt_mapping.extend(
[lora_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
is not None else 1))
mm_data = seq_group_metadata.multi_modal_data
if mm_data:
# Process multi-modal data
mm_kwargs = self.multi_modal_input_mapper(mm_data)
multi_modal_inputs_list.append(mm_kwargs)
if prompt_adapter_id > 0 and is_prompt:
prompt_adapter_requests.add(
seq_group_metadata.prompt_adapter_request)
num_tokens = seq_group_metadata.\
prompt_adapter_num_virtual_tokens
pm = [prompt_adapter_id
] * num_tokens + [0] * (query_len - num_tokens)
prompt_adapter_index_mapping += pm
prompt_adapter_prompt_mapping.extend(
[prompt_adapter_id] *
(query_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
else 1))
is_profile_run = _is_block_tables_empty(
seq_group_metadata.block_tables)
if is_profile_run:
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with
# _PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
if is_prompt:
assert self.scheduler_config.use_v2_block_manager \
or context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention in V1 block manager")
# It is an optimization. When it is decoding, it is always
# 0. When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx = max(0, query_len - self.sliding_window)
for i in range(context_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
# Prepare input tensors for flashinfer
if self.attn_backend.get_name() == "flashinfer":
seq_len = seq_data.get_len()
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
block_table_bound = seq_len // self.block_size + 1 \
if seq_len % self.block_size != 0 \
else seq_len // self.block_size
paged_kv_indices.extend(block_table[:block_table_bound])
paged_kv_indptr.append(paged_kv_indptr[-1] +
block_table_bound)
last_page_len = seq_len % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
paged_kv_last_page_len.append(last_page_len)
batch_size = len(input_tokens)
max_query_len = max(query_lens)
max_prefill_seq_len = max(prefill_seq_lens, default=0)
max_decode_seq_len = max(decode_seq_lens, default=0)
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
use_captured_graph = (
decode_only and not self.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.max_seq_len_to_capture)
if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
for _ in range(graph_batch_size - batch_size):
input_tokens.append(0)
input_positions.append(0)
slot_mapping.append(_PAD_SLOT_ID)
seq_lens.append(1)
block_tables.append([])
lora_index_mapping.append(0)
prompt_adapter_index_mapping.append(0)
if self.attn_backend.get_name() == "flashinfer":
last_paged_kv_indptr = paged_kv_indptr[-1]
paged_kv_indptr.append(last_paged_kv_indptr)
paged_kv_last_page_len.append(0)
batch_size = graph_batch_size
num_decode_tokens = batch_size
if use_captured_graph:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.graph_block_tables[:batch_size]
for i, block_table in enumerate(block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=self.device)
else:
max_block_table_len = max(
len(block_table) for block_table in block_tables)
block_tables = make_tensor_with_pad(
block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=self.device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions_tensor = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping_tensor = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
logits_soft_cap = getattr(self.model_config.hf_config,
'attn_logit_softcapping', None)
if logits_soft_cap is not None and self.attn_backend.get_name(
) != "flashinfer":
raise ValueError("Please use Flashinfer backend for models with"
"logits_soft_cap (i.e., Gemma-2)."
" Otherwise, the output might be wrong."
" Set Flashinfer backend by "
"export VLLM_ATTENTION_BACKEND=FLASHINFER.")
if self.attn_backend.get_name() == "flashinfer":
if len(paged_kv_indptr) > 0:
paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
device='cpu',
dtype=torch.int)
paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
device='cpu',
dtype=torch.int)
paged_kv_last_page_len_tensor = torch.tensor(
paged_kv_last_page_len, device='cpu', dtype=torch.int)
else:
paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None
kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
self.model_config.dtype)
attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
max_prefill_seq_len=max_prefill_seq_len,
block_tables=block_tables,
paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
num_qo_heads=self.model_config.get_num_attention_heads(
self.parallel_config),
num_kv_heads=self.model_config.get_num_kv_heads(
self.parallel_config),
head_dim=self.model_config.get_head_size(),
page_size=self.block_size,
seq_start_loc=seq_start_loc,
query_start_loc=query_start_loc,
device=self.device,
data_type=kv_cache_dtype,
use_cuda_graph=use_captured_graph,
logits_soft_cap=logits_soft_cap)
else:
attn_metadata = self.attn_backend.make_metadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
if self.lora_config:
lora_mapping = LoRAMapping(
lora_index_mapping,
lora_prompt_mapping,
)
else:
lora_mapping = None
if self.prompt_adapter_config:
prompt_adapter_mapping = PromptAdapterMapping(
prompt_adapter_index_mapping,
prompt_adapter_prompt_mapping,
)
else:
prompt_adapter_mapping = None
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
request_ids_to_seq_ids = {
seq_group_metadata.request_id:
list(seq_group_metadata.seq_data.keys())
for seq_group_metadata in seq_group_metadata_list
}
return self._model_input_cls(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
attn_metadata=attn_metadata,
seq_lens=seq_lens,
query_lens=query_lens,
lora_mapping=lora_mapping,
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
request_ids_to_seq_ids=request_ids_to_seq_ids,
finished_requests_ids=finished_requests_ids,
prompt_adapter_mapping=prompt_adapter_mapping,
prompt_adapter_requests=prompt_adapter_requests,
)
@torch.inference_mode()
def profile_run(self) -> None:

View File

@@ -113,6 +113,21 @@ class ModelRunnerInputBase(ABC):
raise NotImplementedError
class ModelRunnerInputBuilderBase(ABC, Generic[T]):
"""A builder to create ModelRunnerInputBase objects.
"""
@abstractmethod
def add_seq_group(self, seq_group_metadata):
"""TBA"""
raise NotImplementedError
@abstractmethod
def build(self, *args, **kwargs) -> T:
"""Build metadata with on-device tensors."""
raise NotImplementedError
class ModelRunnerBase(ABC, Generic[T]):
"""
Model runner interface that abstracts a particular hardware and/or type of