[Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode to a single API (#4681)

This PR combines prepare_prompt and prepare_decode into a single API. It also coalesces the attention metadata for prefill/decode into a single class and allows it to be sliced when running the attention backend.

It also renames subquery_start_loc to query_start_loc, which was not refactored in the previous PR.
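To make the new shape concrete, here is a minimal, self-contained sketch of the coalesced layout (illustrative only: ToyAttentionMetadata and the concrete numbers are made up for this example, while the real AttentionMetadata, FlashAttentionMetadata, and ModelInput classes appear in the diff below). Prefill tokens are batched first, decode tokens second, and the prefill_metadata / decode_metadata properties return per-stage views by slicing.

# Toy version of the coalesced attention metadata introduced by this PR.
# Prefill entries come first, decode entries second; per-stage views are
# obtained by slicing with num_prefills / num_prefill_tokens.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyAttentionMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decode_tokens: int
    slot_mapping: List[int]  # length = num_prefill_tokens + num_decode_tokens
    seq_lens: List[int]      # prefill seq lens first, then decode seq lens

    @property
    def prefill_metadata(self) -> Optional["ToyAttentionMetadata"]:
        if self.num_prefills == 0:
            return None
        return ToyAttentionMetadata(
            num_prefills=self.num_prefills,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=0,
            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
            seq_lens=self.seq_lens[:self.num_prefills],
        )

    @property
    def decode_metadata(self) -> Optional["ToyAttentionMetadata"]:
        if self.num_decode_tokens == 0:
            return None
        return ToyAttentionMetadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=self.num_decode_tokens,
            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
            seq_lens=self.seq_lens[self.num_prefills:],
        )


# A hybrid batch: two prefill requests (3 and 2 tokens) and two decode requests.
meta = ToyAttentionMetadata(
    num_prefills=2,
    num_prefill_tokens=5,
    num_decode_tokens=2,
    slot_mapping=[10, 11, 12, 20, 21, 35, 47],
    seq_lens=[3, 2, 9, 17],
)
assert meta.prefill_metadata.slot_mapping == [10, 11, 12, 20, 21]
assert meta.decode_metadata.seq_lens == [9, 17]

In the actual change, ModelRunner._prepare_model_input builds one such metadata object for the whole (possibly hybrid) batch, and each attention backend reads prefill_metadata and decode_metadata to run its prefill and decode kernels.
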
SangBin Cho 2024-05-15 14:00:10 +09:00 committed by GitHub
parent 8a7cc254a0
commit 65bf2ac165
18 changed files with 781 additions and 734 deletions

View File

@ -58,19 +58,25 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices.append(selected_token_start_idx +
seq_len - 1)
selected_token_start_idx += seq_len
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = model_input.slot_mapping
assert return_seq_lens == seq_lens
assert len(slot_mapping) == len(input_tokens)
# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.is_prompt is True
assert attn_metadata.num_prefills > 0
assert attn_metadata.num_decode_tokens == 0
assert torch.allclose(
attn_metadata.seq_lens_tensor,
torch.tensor(seq_lens, device=device, dtype=torch.int))
assert attn_metadata.seq_lens == seq_lens
assert attn_metadata.max_seq_len == max(seq_lens)
assert attn_metadata.max_prefill_seq_len == max(seq_lens)
assert attn_metadata.max_decode_seq_len == 0
# Test subquery start locs.
start_idx = 0
@ -79,11 +85,11 @@ def test_prepare_prompt(batch_size):
start_idx += seq_len
start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.subquery_start_loc,
attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
# Test seq start locs. Note that for normal prefill it is
# equivalent to subquery_start_loc.
# equivalent to query_start_loc.
start_idx = 0
seq_start_loc = [start_idx]
for seq_len in seq_lens:
@ -123,7 +129,7 @@ def test_prepare_prompt(batch_size):
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
assert input_tokens == input_positions
torch.allclose(input_tokens, input_positions)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
@ -144,14 +150,18 @@ def test_prepare_decode_cuda_graph(batch_size):
enable_chunked_prefill=False,
)
seq_lens = []
context_lens = []
seq_group_metadata_list = []
# Assume each seq group finishes prefill.
for i in range(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = list(range(seq_len))
context_len = i % (model_runner.block_size - 1) + 1
context_lens.append(context_len)
seq_data = list(range(context_len))
seq_data = SequenceData(seq_data)
seq_data.update_num_computed_tokens(context_len)
# Append one token ID since prefill is finished.
seq_data.append_token_id(1, 0)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
@ -162,18 +172,45 @@ def test_prepare_decode_cuda_graph(batch_size):
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens, model_input.input_positions,
model_input.attn_metadata, model_input.slot_mapping)
assert len(slot_mapping) == len(input_tokens)
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for decode.
device = model_runner.device
assert attn_metadata.is_prompt is False
assert attn_metadata.seq_lens is None
assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None
assert attn_metadata.max_seq_len == max(seq_lens)
assert attn_metadata.num_prefills == 0
assert attn_metadata.num_prefill_tokens == 0
seq_lens = [context_len + 1 for context_len in context_lens]
# seq_lens are padded to expected_bs
for _ in range(expected_bs - len(seq_lens)):
seq_lens.append(1)
assert attn_metadata.seq_lens == seq_lens
start_idx = 0
start_loc = [start_idx]
for _ in context_lens:
# decode has only 1 token for query.
start_idx += 1
start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
start_idx = 0
seq_start_loc = [start_idx]
for seq_len in seq_lens:
start_idx += seq_len
seq_start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.seq_start_loc,
torch.tensor(seq_start_loc, dtype=torch.int32, device=device))
assert torch.allclose(
attn_metadata.context_lens_tensor,
torch.tensor(context_lens, dtype=torch.int, device=device))
assert attn_metadata.max_decode_seq_len == max(seq_lens)
assert torch.allclose(
attn_metadata.seq_lens_tensor[:len(seq_lens)],
torch.tensor(seq_lens, dtype=torch.int, device=device))
@ -185,23 +222,23 @@ def test_prepare_decode_cuda_graph(batch_size):
# It is padded up to
assert attn_metadata.block_tables.shape[1] == (
model_runner.get_max_block_per_batch())
# Cuda graph should not be used for prefill.
assert attn_metadata.use_cuda_graph is True
assert len(input_tokens) == expected_bs
assert len(input_positions) == expected_bs
assert input_tokens == input_positions
torch.allclose(input_tokens, input_positions)
# Verify Sampling
expected_selected_token_indices = []
selected_token_start_idx = 0
for seq_len in seq_lens:
for _ in context_lens:
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
# query lens is all 1 for decode.
query_lens=[1 for _ in range(len(context_lens))],
device=model_runner.device,
pin_memory=model_runner.pin_memory)
actual = sampling_metadata.selected_token_indices
@ -220,15 +257,27 @@ def test_empty_seq_group():
enforce_eager=False,
)
seq_group_metadata_list = []
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, slot_mapping,
return_seq_lens) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
model_input.seq_lens,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
@ -285,9 +334,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# Add decode requests
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(seq_len))
context_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(context_len))
seq_data = SequenceData(prompt_toks)
seq_data.append_token_id(1, 0)
seq_data.update_num_computed_tokens(context_len)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
@ -308,23 +359,17 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
assert len(attn_metadata.slot_mapping) == len(input_tokens)
assert len(input_positions) == len(input_tokens)
assert attn_metadata.num_prefills == prefill_batch_size
if enforce_eager:
assert attn_metadata.num_decode_tokens == decode_batch_size
else:
assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
decode_batch_size)
assert attn_metadata.num_decode_tokens == decode_batch_size
assert attn_metadata.num_prefill_tokens == sum(seq_lens)
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
prefill_meta = model_runner._prepare_prompt(
prefill_metadata_list).attn_metadata
decode_meta = model_runner._prepare_decode(
decode_metadata_list).attn_metadata
attn_metadata = model_runner._prepare_model_input(
seq_group_metadata_list).attn_metadata
for attr_expected, attr_actual in zip(vars(prefill_meta),
for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
vars(prefill_meta_actual)):
assert attr_expected[1] == attr_actual[1]
for attr_expected, attr_actual in zip(vars(decode_meta),
for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata),
vars(decode_meta_actual)):
assert attr_expected[1] == attr_actual[1]

View File

@ -1,6 +1,5 @@
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataPerStage)
AttentionMetadata)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
@ -8,6 +7,6 @@ __all__ = [
"Attention",
"AttentionBackend",
"AttentionMetadata",
"AttentionMetadataPerStage",
"Attention",
"get_attn_backend",
]

View File

@ -21,7 +21,7 @@ class AttentionBackend(ABC):
@staticmethod
@abstractmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage":
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
raise NotImplementedError
@staticmethod
@ -53,8 +53,34 @@ class AttentionBackend(ABC):
@dataclass
class AttentionMetadataPerStage:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
class AttentionMetadata:
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
@property
@abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
"""Return the attention metadata that's required to run prefill
attention."""
pass
@property
@abstractmethod
def decode_metadata(self) -> Optional["AttentionMetadata"]:
"""Return the attention metadata that's required to run decode
attention."""
pass
def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
@ -70,40 +96,10 @@ class AttentionMetadataPerStage:
}
T = TypeVar("T", bound=AttentionMetadataPerStage)
T = TypeVar("T", bound=AttentionMetadata)
@dataclass
class AttentionMetadata(Generic[T]):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# The attention metadata for prefill requests in a batch.
# None if there's no prefill requests in a batch.
prefill_metadata: Optional[T]
# The attention metadata for decode requests in a batch.
# None if there's no decode requests in a batch.
decode_metadata: Optional[T]
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
def __post_init__(self):
if self.num_prefill_tokens > 0:
assert self.num_prefills > 0
assert self.prefill_metadata is not None
if self.num_decode_tokens > 0:
assert self.decode_metadata is not None
class AttentionImpl(ABC):
class AttentionImpl(ABC, Generic[T]):
@abstractmethod
def __init__(
@ -125,7 +121,7 @@ class AttentionImpl(ABC):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_metadata: T,
kv_scale: float = 1.0,
) -> torch.Tensor:
raise NotImplementedError

View File

@ -11,8 +11,7 @@ import torch
from vllm_flash_attn import flash_attn_varlen_func
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
AttentionMetadata)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
@ -58,8 +57,7 @@ class FlashAttentionBackend(AttentionBackend):
@dataclass
class FlashAttentionMetadata(AttentionMetadataPerStage,
PagedAttentionMetadata):
class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
@ -67,9 +65,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
@ -84,14 +79,18 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch.
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
query_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
@ -105,6 +104,70 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
_cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
@property
def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.block_tables is not None
assert self.seq_start_loc is not None
self._cached_prefill_metadata = FlashAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.block_tables is not None
assert self.seq_lens_tensor is not None
self._cached_decode_metadata = FlashAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
return self._cached_decode_metadata
class FlashAttentionImpl(AttentionImpl):
"""
@ -168,7 +231,7 @@ class FlashAttentionImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[FlashAttentionMetadata],
attn_metadata: FlashAttentionMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@ -228,8 +291,8 @@ class FlashAttentionImpl(AttentionImpl):
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
@ -249,7 +312,7 @@ class FlashAttentionImpl(AttentionImpl):
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
@ -264,7 +327,7 @@ class FlashAttentionImpl(AttentionImpl):
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
decode_meta.max_decode_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,

View File

@ -8,8 +8,7 @@ from vllm_flash_attn import flash_attn_varlen_func
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
AttentionMetadata)
class FlashInferBackend(AttentionBackend):
@ -56,9 +55,10 @@ class FlashInferBackend(AttentionBackend):
@dataclass
class FlashInferMetadata(AttentionMetadataPerStage):
is_prompt: bool
class FlashInferMetadata(AttentionMetadata):
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
use_cuda_graph: bool = False
@ -67,7 +67,6 @@ class FlashInferMetadata(AttentionMetadataPerStage):
# Metadata for the prefill stage since we still
# use flash attention for prefill.
seq_start_loc: Optional[torch.Tensor] = None
max_seq_len: Optional[int] = None
block_tables: Optional[torch.Tensor] = None
# Metadata for the decode stage
@ -113,7 +112,8 @@ class FlashInferMetadata(AttentionMetadataPerStage):
# When using flashinfer, we are also creating the FlashInferMetadata,
# which will also call post_init by default, here we want to skip the
# post_init if it's the prefill phase.
if not self.is_prompt:
if self.num_prefills == 0:
assert self.num_decode_tokens > 0
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
self.workspace_buffer, "NHD")
self.decode_wrapper.begin_forward(
@ -138,6 +138,24 @@ class FlashInferMetadata(AttentionMetadataPerStage):
skip_fields.add('decode_wrapper')
return super().asdict_zerocopy(skip_fields)
@property
def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
# Currently chunked prefill is not supported
if self.num_decode_tokens == 0:
assert self.num_prefills > 0
return self
return None
@property
def decode_metadata(self) -> Optional["FlashInferMetadata"]:
# Currently chunked prefill is not supported
if self.num_prefills > 0:
assert self.num_decode_tokens == 0
return None
return self
class FlashInferImpl(AttentionImpl):
@ -172,7 +190,7 @@ class FlashInferImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[FlashInferMetadata],
attn_metadata: FlashInferMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor:
assert kv_scale == 1.0
@ -208,8 +226,8 @@ class FlashInferImpl(AttentionImpl):
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,

View File

@ -6,8 +6,7 @@ import torch
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
AttentionMetadata)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
@ -56,8 +55,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
PagedAttentionMetadata):
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
@ -65,9 +63,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
@ -82,14 +77,18 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch.
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
query_start_loc: Optional[torch.Tensor]
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
@ -102,6 +101,69 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
@property
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.block_tables is not None
assert self.seq_start_loc is not None
self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.block_tables is not None
assert self.seq_lens_tensor is not None
self._cached_decode_metadata = ROCmFlashAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
return self._cached_decode_metadata
class ROCmFlashAttentionImpl(AttentionImpl):
@ -198,7 +260,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
attn_metadata: ROCmFlashAttentionMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@ -266,8 +328,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
None,
prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc,
prefill_meta.max_seq_len,
prefill_meta.max_seq_len,
prefill_meta.max_prefill_seq_len,
prefill_meta.max_prefill_seq_len,
True,
self.scale,
)
@ -290,8 +352,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_seq_len,
max_seqlen_k=prefill_meta.max_seq_len,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
@ -308,7 +370,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
@ -324,7 +386,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
decode_meta.max_decode_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,

View File

@ -7,8 +7,7 @@ import torch
from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
AttentionMetadata)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
@ -54,8 +53,7 @@ class TorchSDPABackend(AttentionBackend):
@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
AttentionMetadataPerStage):
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
@ -72,8 +70,26 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[torch.Tensor]] = None
@property
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
# Currently chunked prefill is not supported
if self.num_decode_tokens == 0:
assert self.num_prefills > 0
return self
class TorchSDPABackendImpl(AttentionImpl):
return None
@property
def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
# Currently chunked prefill is not supported
if self.num_prefills > 0:
assert self.num_decode_tokens == 0
return None
return self
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
def __init__(
self,
@ -200,7 +216,7 @@ class TorchSDPABackendImpl(AttentionImpl):
value_cache,
attn_metadata.block_tables,
attn_metadata.seq_lens_tensor,
attn_metadata.max_seq_len,
attn_metadata.max_decode_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,

View File

@ -9,8 +9,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
LowerTriangularMaskWithTensorBias)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
AttentionMetadataPerStage)
AttentionMetadata)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
@ -59,7 +58,7 @@ class XFormersBackend(AttentionBackend):
@dataclass
class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
"""Metadata for XFormersbackend.
NOTE: Any python object stored here is not updated when it is
@ -67,9 +66,6 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens: Optional[List[int]]
@ -83,15 +79,19 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch.
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int]
# FIXME: It is for flash attn.
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
query_start_loc: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
@ -105,6 +105,8 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
_cached_prefill_metadata: Optional["XFormersMetadata"] = None
_cached_decode_metadata: Optional["XFormersMetadata"] = None
def __post_init__(self):
# Set during the execution of the first attention op.
@ -114,8 +116,68 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
# will not appear in the __repr__ and __init__
self.attn_bias: Optional[List[AttentionBias]] = None
@property
def prefill_metadata(self) -> Optional["XFormersMetadata"]:
if self.num_prefills == 0:
return None
class XFormersImpl(AttentionImpl):
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert self.seq_lens is not None
assert self.seq_lens_tensor is not None
assert self.query_start_loc is not None
assert self.context_lens_tensor is not None
assert self.block_tables is not None
self._cached_prefill_metadata = XFormersMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=None,
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["XFormersMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert self.block_tables is not None
assert self.seq_lens_tensor is not None
self._cached_decode_metadata = XFormersMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
return self._cached_decode_metadata
class XFormersImpl(AttentionImpl[XFormersMetadata]):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
@ -176,7 +238,7 @@ class XFormersImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[XFormersMetadata],
attn_metadata: "XFormersMetadata",
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@ -244,7 +306,7 @@ class XFormersImpl(AttentionImpl):
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.context_lens_tensor,
prefill_meta.max_query_len,
@ -261,7 +323,7 @@ class XFormersImpl(AttentionImpl):
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_seq_len,
decode_meta.max_decode_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,

View File

@ -4,8 +4,7 @@ from typing import List, Optional
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import (AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
@ -57,7 +56,7 @@ class Attention(nn.Module):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
attn_metadata: AttentionMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,

View File

@ -16,8 +16,8 @@ class PagedAttentionMetadata:
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor: Optional[torch.Tensor]
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
max_decode_seq_len: int
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
@ -166,7 +166,7 @@ class PagedAttention:
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
subquery_start_loc: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens_tensor: torch.Tensor,
context_lens: torch.Tensor,
max_query_len: int,
@ -182,8 +182,8 @@ class PagedAttention:
key_cache,
value_cache,
block_tables,
# subquery_start_loc is (batch_size + 1,)
subquery_start_loc[:-1],
# query_start_loc is (batch_size + 1,)
query_start_loc[:-1],
seq_lens_tensor,
context_lens,
max_query_len,

View File

@ -618,6 +618,11 @@ class EngineArgs:
decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)
if (model_config.get_sliding_window() is not None
and scheduler_config.chunked_prefill_enabled):
raise ValueError(
"Chunked prefill is not supported with sliding window.")
return EngineConfig(model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,

View File

@ -122,6 +122,7 @@ class RejectionSampler(nn.Module):
draft_token_ids,
bonus_token_ids,
)
return output_token_ids
def _batch_modified_rejection_sampling(

View File

@ -654,8 +654,9 @@ class SequenceGroupMetadata:
return self.lora_request.lora_int_id if self.lora_request else 0
@property
def token_chunk_size(self) -> Optional[int]:
def token_chunk_size(self) -> int:
"""Return the number of tokens to be processed (chunk size)."""
assert self._token_chunk_size is not None
return self._token_chunk_size

View File

@ -293,21 +293,30 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
prompt_token_ids = seq_data.get_prompt_token_ids()
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
new_seq_data_dict = {
target_seq_id:
SequenceData(
prompt_token_ids=prompt_token_ids,
output_token_ids=new_output_token_ids,
),
}
# This is a hack. Technically, spec decoding should compute
# num_lookahead slots at one shot, but instead, it expands the batch
# and evaluate one by one right now. context_len is seq_len - 1 because
# the kv cache is filled by a previous batch in the batch expansion.
for data in new_seq_data_dict.values():
data.update_num_computed_tokens(data.get_len() - 1)
return SequenceGroupMetadata(
request_id=seq_group_metadata.request_id,
is_prompt=seq_group_metadata.is_prompt,
seq_data={
target_seq_id:
SequenceData(
prompt_token_ids=prompt_token_ids,
output_token_ids=new_output_token_ids,
),
},
seq_data=new_seq_data_dict,
sampling_params=seq_group_metadata.sampling_params,
block_tables={
target_seq_id: seq_group_metadata.block_tables[seq_id],
},
lora_request=None,
token_chunk_size=1,
)
def _split_scoring_output(

View File

@ -114,6 +114,7 @@ class MultiStepWorker(Worker):
token_logprob = seq_output.logprobs[token_id]
seq.append_token_id(token_id, token_logprob.logprob)
seq.update_num_computed_tokens(1)
def _shallow_copy_inputs(
self, seq_group_metadata_list: List[SequenceGroupMetadata]

View File

@ -159,12 +159,10 @@ class CPUModelRunner:
is_prompt=True,
seq_lens=seq_lens,
seq_lens_tensor=None,
max_seq_len=None,
max_decode_seq_len=None,
num_prefills=len(seq_lens),
num_prefill_tokens=num_prompt_tokens,
num_decode_tokens=0,
prefill_metadata=None,
decode_metadata=None,
block_tables=torch.tensor([]),
slot_mapping=slot_mapping,
)
@ -213,7 +211,7 @@ class CPUModelRunner:
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
max_seq_len = max(seq_lens)
max_decode_seq_len = max(seq_lens)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
@ -243,12 +241,10 @@ class CPUModelRunner:
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_seq_len=max_seq_len,
max_decode_seq_len=max_decode_seq_len,
num_prefill_tokens=0,
num_decode_tokens=len(input_tokens),
num_prefills=0,
prefill_metadata=None,
decode_metadata=None,
block_tables=block_tables,
)
return (

View File

@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import BatchType, ModelRunner
from vllm.worker.model_runner import ModelRunner
logger = init_logger(__name__)
@ -88,85 +88,24 @@ class EmbeddingModelRunner(ModelRunner):
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
Set[LoRARequest], LoRAMapping, torch.Tensor]:
if self.is_driver_worker:
prefill_reqs = []
decode_reqs = []
for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
prefill_reqs.append(seq_group_meta)
else:
decode_reqs.append(seq_group_meta)
# Prepare input tensors.
(
input_tokens,
input_positions,
prefill_attn_metadata,
prompt_lens,
subquery_lens,
lora_index_mapping,
lora_prompt_mapping,
attn_metadata,
seq_lens,
_,
lora_mapping,
lora_requests,
multi_modal_input,
slot_mapping,
) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
decode_input_positions,
decode_attn_metadata,
decode_lora_index_mapping,
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
num_prefill_tokens,
num_decode_tokens,
num_prefills,
) = self._prepare_model_input(seq_group_metadata_list)
# Prepare PoolingMetadata
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
prompt_lens)
if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
num_prefills = len(prompt_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)
# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens.extend(decode_input_tokens)
input_positions.extend(decode_input_positions)
slot_mapping.extend(decode_slot_mapping)
lora_index_mapping.extend(decode_lora_index_mapping)
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
lora_requests.update(decode_lora_requests)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
if self.lora_config:
lora_mapping = LoRAMapping(
lora_index_mapping,
lora_prompt_mapping,
)
else:
lora_mapping = None
# Broadcast the metadata.
# If batch contains both prefill and decode, it sends 2 broadcasts.
# If it only contains 1 type, it triggers a single broadcast.
if (prefill_attn_metadata is not None
and decode_attn_metadata is not None):
batch_type = BatchType.MIXED
elif prefill_attn_metadata is not None:
batch_type = BatchType.PREFILL
else:
batch_type = BatchType.DECODE
seq_lens)
metadata_dict = {
"input_tokens": input_tokens,
@ -178,65 +117,26 @@ class EmbeddingModelRunner(ModelRunner):
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
"batch_type": batch_type,
}
if prefill_attn_metadata is not None:
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
else:
assert decode_attn_metadata is not None
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
if attn_metadata:
metadata_dict.update(attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
# Broadcast decode attn metadata for mixed batch type.
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
# We can potentially reduce the overhead by coelescing tensors.
if batch_type == BatchType.MIXED:
assert decode_attn_metadata is not None
metadata_dict = decode_attn_metadata.asdict_zerocopy()
broadcast_tensor_dict(metadata_dict, src=0)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions")
slot_mapping = metadata_dict.pop("slot_mapping")
num_prefills = metadata_dict.pop("num_prefills")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input")
num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
batch_type = metadata_dict.pop("batch_type")
# Create an attention metadata.
prefill_attn_metadata = None
decode_attn_metadata = None
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
prefill_attn_metadata = self.attn_backend.make_metadata(
if metadata_dict:
attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
else:
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
attn_metadata = None
pooling_metadata = PoolingMetadata(seq_groups=None,
seq_data=None,
prompt_lens=None)
# if it is a mixed batch, decode attn_metadata is broadcasted
# separately.
if batch_type == BatchType.MIXED:
metadata_dict = broadcast_tensor_dict(src=0)
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
attn_metadata = AttentionMetadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata,
)
return (input_tokens, input_positions, attn_metadata, pooling_metadata,
lora_requests, lora_mapping, multi_modal_input)

View File

@ -1,13 +1,11 @@
import time
from enum import IntEnum
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend)
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
@ -37,66 +35,38 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
]
class PreparePromptMetadata(NamedTuple):
input_tokens: List[int]
input_positions: List[int]
attn_metadata: Optional[AttentionMetadataPerStage]
class ModelInput(NamedTuple):
input_tokens: torch.Tensor
input_positions: torch.Tensor
attn_metadata: Optional[AttentionMetadata]
seq_lens: List[int]
query_lens: List[int]
lora_index_mapping: List[int]
lora_prompt_mapping: List[int]
lora_mapping: Optional[LoRAMapping]
lora_requests: Set[LoRARequest]
multi_modal_input: Optional[torch.Tensor]
slot_mapping: List[int]
slot_mapping: torch.Tensor
num_prefill_tokens: int
num_decode_tokens: int
num_prefills: int
@classmethod
def empty(cls):
return PreparePromptMetadata(
input_tokens=[],
input_positions=[],
def empty(cls, device):
return ModelInput(
input_tokens=torch.empty(0, device=device),
input_positions=torch.empty(0, device=device),
attn_metadata=None,
seq_lens=[],
query_lens=[],
lora_index_mapping=[],
lora_prompt_mapping=[],
lora_mapping=None,
lora_requests=set(),
multi_modal_input=None,
slot_mapping=[],
slot_mapping=torch.empty(0, device=device),
num_prefill_tokens=0,
num_decode_tokens=0,
num_prefills=0,
)
class PrepareDecodeMetadata(NamedTuple):
input_tokens: List[int]
input_positions: List[int]
attn_metadata: Optional[AttentionMetadata]
lora_index_mapping: List[int]
lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest]
slot_mapping: List[int]
@classmethod
def empty(cls):
return PrepareDecodeMetadata(
input_tokens=[],
input_positions=[],
attn_metadata=None,
lora_index_mapping=[],
lora_prompt_mapping=[],
lora_requests=set(),
slot_mapping=[],
)
# How batches are constructed.
class BatchType(IntEnum):
# Every batch is prefill.
PREFILL = 0
# Every batch is decode.
DECODE = 1
# Batch is a mixture of prefill and decode.
MIXED = 2
class ModelRunner:
def __init__(
@ -216,10 +186,22 @@ class ModelRunner:
block_size = self.block_size
return (self.max_seq_len_to_capture + block_size - 1) // block_size
def _prepare_prompt(
def _prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> PreparePromptMetadata:
) -> ModelInput:
"""Prepare the model input based on a given sequence group.
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
The result tensors and data structure also batches input in prefill
-> decode order. For example,
- input_tokens[:num_prefill_tokens] contains prefill tokens.
- input_tokens[num_prefill_tokens:] contains decode tokens.
If cuda graph is required, this API automatically pads inputs.
"""
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
@ -228,212 +210,16 @@ class ModelRunner:
lora_requests: Set[LoRARequest] = set()
seq_lens: List[int] = []
prefill_seq_lens: List[int] = []
decode_seq_lens: List[int] = []
context_lens: List[int] = []
query_lens: List[int] = []
prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = []
if len(seq_group_metadata_list) == 0:
return PreparePromptMetadata.empty()
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")
token_chunk_size = seq_group_metadata.token_chunk_size
seq_data = seq_group_metadata.seq_data[seq_id]
context_len = seq_data.get_num_computed_tokens()
# We should use get_len here because in case of preemption
# it contains output tokens.
seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
seq_lens.append(seq_len)
# NOTE: This only works for oooooooxxx style attention.
if computed_block_nums is not None and len(
computed_block_nums) > 0 and self.sliding_window is None:
# Prefix is not supported with sliding_window
context_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[context_len:]
prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
if seq_group_metadata.block_tables is not None:
# Prefill has chunked before.
block_table = seq_group_metadata.block_tables[seq_id]
prefix_block_tables.append(block_table)
else:
# The first prefill.
prefix_block_tables.append([])
else:
prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this
# assumption can be changed once chunked prefill is introduced.
assert context_len == 0
# actual prompt lens
context_lens.append(context_len)
query_lens.append(seq_len - context_len)
input_tokens.extend(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(list(range(context_len, seq_len)))
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping += [lora_id] * (seq_len - context_len)
lora_prompt_mapping.extend([lora_id] * (
seq_len - context_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs else 1))
if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append(
seq_group_metadata.multi_modal_data.data)
if _is_block_tables_empty(seq_group_metadata.block_tables):
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
assert context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention")
start_idx = max(0, seq_len - self.sliding_window)
for i in range(context_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
max_query_len = max(query_lens)
max_seq_len = max(seq_lens)
assert max_query_len > 0
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
if multi_modal_input_list:
assert self.vision_language_config, (
"Multi-modal inputs are only supported by "
"vision language models.")
multi_modal_input = torch.cat(multi_modal_input_list,
dim=0).to(self.device)
else:
multi_modal_input = None
# Prepare prefix block tables
max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
block_tables = make_tensor_with_pad(
prefix_block_tables,
max_len=max_prompt_block_table_len,
pad=0,
dtype=torch.int,
device=self.device,
)
# Query length can be shorter than key (i.e., prompt) when prefill
# is chunked or prefix cached.
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
torch.cumsum(query_lens_tensor,
dim=0,
dtype=subquery_start_loc.dtype,
out=subquery_start_loc[1:])
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
if self.attn_backend.get_name() == "flashinfer":
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
use_cuda_graph=False,
seq_start_loc=seq_start_loc,
max_seq_len=max_seq_len,
block_tables=block_tables)
else:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
)
return PreparePromptMetadata(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
seq_lens=seq_lens,
query_lens=query_lens,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
multi_modal_input=multi_modal_input,
slot_mapping=slot_mapping,
)
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> PrepareDecodeMetadata:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
multi_modal_input_list: List[torch.Tensor] = []
decode_only = True
num_prefills = 0
num_prefill_tokens = 0
num_decode_tokens = 0
# The following fields are only for flashinfer
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
@ -454,60 +240,186 @@ class ModelRunner:
paged_kv_last_page_len: List[int] = []
if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty()
return ModelInput.empty(self.device)
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1
seq_ids = list(seq_group_metadata.seq_data.keys())
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
is_prompt = seq_group_metadata.is_prompt
for seq_id in seq_ids:
computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append(generation_token)
if is_prompt:
context_len = seq_data.get_num_computed_tokens()
else:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len = seq_data.get_len() - 1
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append(position)
seq_len = min(
seq_data.get_len(),
context_len + seq_group_metadata.token_chunk_size)
if is_prompt:
tokens = seq_data.get_token_ids()[context_len:seq_len]
else:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens = [seq_data.get_last_token_id()]
seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
seq_lens.append(seq_len)
# Prefix cache was hit.
# Prefix is not supported with sliding_window
prefix_cache_hit = (computed_block_nums is not None
and len(computed_block_nums) > 0
and self.sliding_window is None
and is_prompt)
block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
lora_index_mapping.append(lora_id)
lora_prompt_mapping.append(lora_id)
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing chunk sizes that are a multiple of block_size.
# NOTE: This only works for oooooooxxx style attention.
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size
tokens = tokens[context_len:]
if self.attn_backend.get_name() == "flash-attn":
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
# TODO(woosuk): This is a temporary fix. We should
# provide a unified interface for different backends.
block_table = seq_group_metadata.block_tables[seq_id]
else:
block_table = computed_block_nums
elif (self.scheduler_config.chunked_prefill_enabled
or not is_prompt):
if seq_group_metadata.block_tables is not None:
# chunked prefill or decode
block_table = seq_group_metadata.block_tables[seq_id]
if self.sliding_window is not None:
# chunked prefill doesn't support sliding window.
assert (not self.scheduler_config.
chunked_prefill_enabled)
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]
if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]
if self.attn_backend.get_name() == "flashinfer":
paged_kv_indices.extend(block_table)
paged_kv_indptr.append(paged_kv_indptr[-1] +
len(block_table))
last_page_len = seq_data.get_len() % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
paged_kv_last_page_len.append(last_page_len)
else:
# Only happens when memory profiling runs.
block_table = []
else:
# Prefill without chunked prefill or memory profiling.
block_table = []
block_tables.append(block_table)
paged_kv_indices.extend(block_table)
paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
last_page_len = seq_data.get_len() % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
paged_kv_last_page_len.append(last_page_len)
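For context, the flashinfer fields built here follow the paged KV layout from the linked tutorial: paged_kv_indices concatenates every sequence's block table, paged_kv_indptr keeps the running prefix sums, and paged_kv_last_page_len records how many tokens occupy each sequence's final page. A minimal standalone sketch; the block tables, sequence lengths, and page_size below are made up for illustration:

from typing import List, Tuple

def build_paged_kv_layout(
    block_tables: List[List[int]], seq_lens: List[int], page_size: int
) -> Tuple[List[int], List[int], List[int]]:
    paged_kv_indices: List[int] = []   # all page ids, concatenated per sequence
    paged_kv_indptr: List[int] = [0]   # seq i's pages live in
                                       # indices[indptr[i]:indptr[i + 1]]
    paged_kv_last_page_len: List[int] = []  # valid tokens in each last page
    for block_table, seq_len in zip(block_tables, seq_lens):
        paged_kv_indices.extend(block_table)
        paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
        paged_kv_last_page_len.append(seq_len % page_size or page_size)
    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len

# Two sequences of 9 and 20 tokens stored in pages of 16 slots:
print(build_paged_kv_layout([[3], [7, 2]], [9, 20], 16))
# ([3, 7, 2], [0, 1, 3], [9, 4])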
# TODO(sang): This is a hack to make sliding window work with
# paged attention. We can remove it once the paged attention
# kernel properly handles sliding-window attention.
if (self.sliding_window is not None and not is_prompt):
seq_len = min(seq_len, self.sliding_window)
context_len = seq_len - 1
seq_lens.append(seq_len)
context_lens.append(context_len)
query_len = seq_len - context_len
query_lens.append(query_len)
input_tokens.extend(tokens)
input_positions.extend(list(range(context_len, seq_len)))
lora_id = seq_group_metadata.lora_int_id
if is_prompt:
assert len(seq_ids) == 1
num_prefills += 1
num_prefill_tokens += len(tokens)
decode_only = False
prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
num_decode_tokens += query_len
decode_seq_lens.append(seq_len)
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping += [lora_id] * (seq_len - context_len)
lora_prompt_mapping.extend(
[lora_id] *
(seq_len -
context_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
else 1))
if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append(
seq_group_metadata.multi_modal_data.data)
if _is_block_tables_empty(seq_group_metadata.block_tables):
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# For embedding models, the block tables are {seq_id: None}.
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue
# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with
# _PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
if is_prompt:
assert context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention")
# This is an optimization: during decode, start_idx is always
# 0; during prefill, we use it to avoid writing slots outside
# the sliding window to the KV cache, saving memory.
start_idx = max(0, query_len - self.sliding_window)
for i in range(context_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
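As a standalone illustration of the sliding-window slot mapping described in the comment above (for a fresh prefill query_len equals seq_len, so start_idx = max(0, seq_len - sliding_window)); _PAD_SLOT_ID, the block table, and the sizes here are hypothetical:

_PAD_SLOT_ID = -1

def sliding_window_slot_mapping(block_table, seq_len, sliding_window, block_size):
    start_idx = max(0, seq_len - sliding_window)  # tokens before this are masked
    slots = []
    for i in range(seq_len):
        if i < start_idx:
            slots.append(_PAD_SLOT_ID)  # KV for this token is never written
            continue
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slots.append(block_number * block_size + block_offset)
    return slots

# Prompt of 10 tokens, sliding window 8, block size 4, blocks [0, 1, 2]:
print(sliding_window_slot_mapping([0, 1, 2], 10, 8, 4))
# [-1, -1, 2, 3, 4, 5, 6, 7, 8, 9]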
# vLLM uses cuda graph only for decoding requests.
# See `capture_model` API for more details.
# For decoding requests, batch_size == len(input_tokens).
batch_size = len(input_tokens)
max_seq_len = max(seq_lens)
use_captured_graph = (not self.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_seq_len <= self.max_seq_len_to_capture)
max_query_len = max(query_lens)
max_prefill_seq_len = max(prefill_seq_lens, default=0)
max_decode_seq_len = max(decode_seq_lens, default=0)
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
use_captured_graph = (
decode_only and not self.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.max_seq_len_to_capture)
if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
@@ -519,18 +431,9 @@ class ModelRunner:
block_tables.append([])
lora_index_mapping.append(0)
batch_size = graph_batch_size
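A hedged sketch of the batch padding this relies on: the decode batch is rounded up to a pre-captured graph size and filled with dummy entries. The rounding rule and capture list below are assumptions for illustration; vLLM's actual _get_graph_batch_size and _BATCH_SIZES_TO_CAPTURE may differ.

_BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]

def get_graph_batch_size(batch_size: int) -> int:
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

# A batch of 13 decode requests is padded with 3 dummy sequences so the
# pre-captured graph for batch size 16 can be replayed.
assert get_graph_batch_size(13) == 16
assert get_graph_batch_size(13) <= _BATCH_SIZES_TO_CAPTURE[-1]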
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
num_decode_tokens = batch_size
if use_captured_graph:
# When using cuda-graph all these tensors should be
# padded.
assert seq_lens_tensor.shape[0] == len(input_tokens)
assert seq_lens_tensor.shape[0] == len(input_positions)
assert seq_lens_tensor.shape[0] == len(slot_mapping)
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.graph_block_tables[:batch_size]
@@ -548,6 +451,57 @@ class ModelRunner:
dtype=torch.int,
device=self.device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
if multi_modal_input_list:
assert self.vision_language_config, (
"Multi-modal inputs are only supported by "
"vision language models.")
multi_modal_input = torch.cat(multi_modal_input_list,
dim=0).to(self.device)
else:
multi_modal_input = None
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
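To make the cumulative-sum construction above concrete, a small self-contained example (the query and sequence lengths below are made up): query_start_loc gives the offset of each sequence's new tokens in the flattened batch, and seq_start_loc does the same for the full (context + new) sequences.

import torch

query_lens = [4, 1, 1]
seq_lens = [4, 7, 9]

query_lens_tensor = torch.tensor(query_lens, dtype=torch.long)
query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
torch.cumsum(query_lens_tensor, dim=0, dtype=query_start_loc.dtype,
             out=query_start_loc[1:])

seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int)
seq_start_loc = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype,
             out=seq_start_loc[1:])

print(query_start_loc.tolist())  # [0, 4, 5, 6]
print(seq_start_loc.tolist())    # [0, 4, 11, 20]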
input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions_tensor = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping_tensor = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
if self.attn_backend.get_name() == "flashinfer":
if not hasattr(self, "flashinfer_workspace_buffer"):
@@ -555,53 +509,75 @@ class ModelRunner:
# Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
self.flashinfer_workspace_buffer = torch.empty(
16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
paged_kv_indptr = torch.tensor(paged_kv_indptr,
dtype=torch.int,
device=self.device)
paged_kv_indices = torch.tensor(paged_kv_indices,
dtype=torch.int,
device=self.device)
paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len,
paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
dtype=torch.int,
device=self.device)
paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
dtype=torch.int,
device=self.device)
paged_kv_last_page_len_tensor = torch.tensor(
paged_kv_last_page_len, dtype=torch.int, device=self.device)
kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
self.model_config.dtype)
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
use_cuda_graph=False,
max_prefill_seq_len=max_prefill_seq_len,
block_tables=block_tables,
workspace_buffer=self.flashinfer_workspace_buffer,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
num_qo_heads=self.model_config.get_num_attention_heads(
self.parallel_config),
num_kv_heads=self.model_config.get_num_kv_heads(
self.parallel_config),
head_dim=self.model_config.get_head_size(),
page_size=self.block_size,
page_size=16,
seq_start_loc=seq_start_loc,
data_type=kv_cache_dtype)
else:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
seq_lens=None,
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=None,
max_seq_len=max_seq_len,
subquery_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
return PrepareDecodeMetadata(
input_tokens=input_tokens,
input_positions=input_positions,
if self.lora_config:
lora_mapping = LoRAMapping(
lora_index_mapping,
lora_prompt_mapping,
)
else:
lora_mapping = None
return ModelInput(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
attn_metadata=attn_metadata,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
seq_lens=seq_lens,
query_lens=query_lens,
lora_mapping=lora_mapping,
lora_requests=lora_requests,
slot_mapping=slot_mapping,
multi_modal_input=multi_modal_input,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
)
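Because the loop above counts prefill and decode tokens separately while appending them into the same flattened lists, an attention backend can slice the coalesced tensors with those counts, assuming the scheduler places prefill groups before decode groups in seq_group_metadata_list. A hedged sketch of that slicing (split_tokens is illustrative, not a vLLM API):

import torch

def split_tokens(input_tokens: torch.Tensor, num_prefill_tokens: int,
                 num_decode_tokens: int):
    # Prefill tokens occupy the front of the flattened batch; decode tokens
    # (one per running sequence) come right after.
    prefill = input_tokens[:num_prefill_tokens]
    decode = input_tokens[num_prefill_tokens:num_prefill_tokens +
                          num_decode_tokens]
    return prefill, decode

# Two prefills of 3 and 2 tokens followed by two single-token decodes.
tokens = torch.arange(7)
prefill, decode = split_tokens(tokens, num_prefill_tokens=5, num_decode_tokens=2)
print(prefill.tolist(), decode.tolist())  # [0, 1, 2, 3, 4] [5, 6]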
def prepare_input_tensors(
@@ -610,85 +586,25 @@ class ModelRunner:
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[LoRARequest], LoRAMapping, torch.Tensor]:
if self.is_driver_worker:
prefill_reqs = []
decode_reqs = []
for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
prefill_reqs.append(seq_group_meta)
else:
decode_reqs.append(seq_group_meta)
# Prepare input tensors.
(
input_tokens,
input_positions,
prefill_attn_metadata,
attn_metadata,
seq_lens,
query_lens,
lora_index_mapping,
lora_prompt_mapping,
lora_mapping,
lora_requests,
multi_modal_input,
slot_mapping,
) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
decode_input_positions,
decode_attn_metadata,
decode_lora_index_mapping,
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
num_prefill_tokens,
num_decode_tokens,
num_prefills,
) = self._prepare_model_input(seq_group_metadata_list)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.pin_memory)
if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
num_prefills = len(seq_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)
# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens.extend(decode_input_tokens)
input_positions.extend(decode_input_positions)
slot_mapping.extend(decode_slot_mapping)
lora_index_mapping.extend(decode_lora_index_mapping)
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
lora_requests.update(decode_lora_requests)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
if self.lora_config:
lora_mapping = LoRAMapping(
lora_index_mapping,
lora_prompt_mapping,
)
else:
lora_mapping = None
# Broadcast the metadata.
# If batch contains both prefill and decode, it sends 2 broadcasts.
# If it only contains 1 type, it triggers a single broadcast.
if (prefill_attn_metadata is not None
and decode_attn_metadata is not None):
batch_type = BatchType.MIXED
elif prefill_attn_metadata is not None:
batch_type = BatchType.PREFILL
else:
batch_type = BatchType.DECODE
metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
@@ -701,46 +617,24 @@ class ModelRunner:
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
"batch_type": batch_type,
}
if prefill_attn_metadata is not None:
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
else:
assert decode_attn_metadata is not None
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
if attn_metadata:
metadata_dict.update(attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
# Broadcast decode attn metadata for mixed batch type.
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
# We can potentially reduce the overhead by coalescing tensors.
if batch_type == BatchType.MIXED:
assert decode_attn_metadata is not None
metadata_dict = decode_attn_metadata.asdict_zerocopy()
broadcast_tensor_dict(metadata_dict, src=0)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions")
slot_mapping = metadata_dict.pop("slot_mapping")
num_prefills = metadata_dict.pop("num_prefills")
selected_token_indices = metadata_dict.pop(
"selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input")
num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
batch_type = metadata_dict.pop("batch_type")
# Create an attention metadata.
prefill_attn_metadata = None
decode_attn_metadata = None
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
prefill_attn_metadata = self.attn_backend.make_metadata(
if metadata_dict:
attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
else:
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
attn_metadata = None
sampling_metadata = SamplingMetadata(
seq_groups=None,
selected_token_indices=selected_token_indices,
@@ -748,22 +642,6 @@ class ModelRunner:
num_prompts=0,
)
# if it is a mixed batch, decode attn_metadata is broadcasted
# separately.
if batch_type == BatchType.MIXED:
metadata_dict = broadcast_tensor_dict(src=0)
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
attn_metadata = AttentionMetadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata,
)
return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping,
multi_modal_input)
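The broadcast path above serializes the attention metadata to a plain dict on the driver and rebuilds it on each worker via make_metadata(**metadata_dict). A minimal sketch of that round-trip using a stand-in dataclass (ToyAttnMetadata and the helper below are illustrative, not the real AttentionMetadata API):

from dataclasses import dataclass, fields

@dataclass
class ToyAttnMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decode_tokens: int

def asdict_zerocopy(meta):
    # Stand-in for AttentionMetadata.asdict_zerocopy(): map field name to
    # value without deep-copying tensor fields.
    return {f.name: getattr(meta, f.name) for f in fields(meta)}

meta = ToyAttnMetadata(num_prefills=2, num_prefill_tokens=11, num_decode_tokens=3)
metadata_dict = asdict_zerocopy(meta)        # what the driver broadcasts
rebuilt = ToyAttnMetadata(**metadata_dict)   # what each worker reconstructs
assert rebuilt == meta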
@@ -954,25 +832,21 @@ class ModelRunner:
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
# Create dummy attn_metadata.
decode_metadata = self.attn_backend.make_metadata(
is_prompt=False,
seq_lens=None,
seq_lens_tensor=seq_lens[:batch_size],
max_query_len=None,
max_seq_len=self.max_seq_len_to_capture,
subquery_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)
attn_metadata = AttentionMetadata(
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping[:batch_size],
prefill_metadata=None,
decode_metadata=decode_metadata,
seq_lens=None,
seq_lens_tensor=seq_lens[:batch_size],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)
if self.lora_config: