Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-05 02:30:55 +08:00)
[Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode to a single API (#4681)
This PR combines prepare_prompt and prepare_decode into a single API. It also coalesces the attention metadata for prefill and decode into a single class that can be sliced per stage when running the attention backend, and renames subquery_start_loc, which was left untouched in the previous PR, to query_start_loc.
This commit is contained in:
parent
8a7cc254a0
commit
65bf2ac165
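To make the shape of the new API concrete before the diff, here is a minimal sketch of how a caller consumes it after this change. It is not taken from the diff verbatim: the field names follow the ModelInput NamedTuple and the AttentionMetadata properties shown below, and the model_runner / seq_group_metadata_list setup is assumed.

# Hedged sketch: consuming the unified prepare API introduced by this PR.
model_input = model_runner._prepare_model_input(seq_group_metadata_list)

input_tokens = model_input.input_tokens      # prefill tokens first, then decode tokens
attn_metadata = model_input.attn_metadata    # one metadata object for the whole batch

# The combined metadata is sliced per stage when running the attention backend.
prefill_meta = attn_metadata.prefill_metadata  # None if the batch has no prefill requests
decode_meta = attn_metadata.decode_metadata    # None if the batch has no decode requests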
@@ -58,19 +58,25 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices.append(selected_token_start_idx +
seq_len - 1)
selected_token_start_idx += seq_len
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = model_input.slot_mapping
assert return_seq_lens == seq_lens
assert len(slot_mapping) == len(input_tokens)

# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.is_prompt is True
assert attn_metadata.num_prefills > 0
assert attn_metadata.num_decode_tokens == 0
assert torch.allclose(
attn_metadata.seq_lens_tensor,
torch.tensor(seq_lens, device=device, dtype=torch.int))
assert attn_metadata.seq_lens == seq_lens
assert attn_metadata.max_seq_len == max(seq_lens)
assert attn_metadata.max_prefill_seq_len == max(seq_lens)
assert attn_metadata.max_decode_seq_len == 0

# Test subquery start locs.
start_idx = 0
@@ -79,11 +85,11 @@ def test_prepare_prompt(batch_size):
start_idx += seq_len
start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.subquery_start_loc,
attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))

# Test seq start locs. Note that for normal prefill it is
# equivalent to subquery_start_loc.
# equivalent to query_start_loc.
start_idx = 0
seq_start_loc = [start_idx]
for seq_len in seq_lens:
@@ -123,7 +129,7 @@ def test_prepare_prompt(batch_size):
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
assert input_tokens == input_positions
torch.allclose(input_tokens, input_positions)

actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
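For readers skimming the assertions above, a minimal sketch of how the cumulative start locations are formed; for a pure prefill batch, query_start_loc and seq_start_loc coincide, and the lengths below are illustrative only.

# Sketch: cumulative start locations, mirroring the loop in the test above.
seq_lens = [4, 6]
start_loc = [0]
for seq_len in seq_lens:
    start_loc.append(start_loc[-1] + seq_len)
assert start_loc == [0, 4, 10]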
@@ -144,14 +150,18 @@ def test_prepare_decode_cuda_graph(batch_size):
enable_chunked_prefill=False,
)

seq_lens = []
context_lens = []
seq_group_metadata_list = []
# Assume each seq group finishes prefill.
for i in range(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = list(range(seq_len))
context_len = i % (model_runner.block_size - 1) + 1
context_lens.append(context_len)
seq_data = list(range(context_len))
seq_data = SequenceData(seq_data)
seq_data.update_num_computed_tokens(context_len)
# Append one token ID since prefill is finished.
seq_data.append_token_id(1, 0)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
@@ -162,18 +172,45 @@ def test_prepare_decode_cuda_graph(batch_size):
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)

input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens, model_input.input_positions,
model_input.attn_metadata, model_input.slot_mapping)
assert len(slot_mapping) == len(input_tokens)

expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.is_prompt is False
assert attn_metadata.seq_lens is None
assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None
assert attn_metadata.max_seq_len == max(seq_lens)
assert attn_metadata.num_prefills == 0
assert attn_metadata.num_prefill_tokens == 0
seq_lens = [context_len + 1 for context_len in context_lens]
# seq_lens are padded to expected_bs
for _ in range(expected_bs - len(seq_lens)):
seq_lens.append(1)
assert attn_metadata.seq_lens == seq_lens
start_idx = 0
start_loc = [start_idx]
for _ in context_lens:
# decode has only 1 token for query.
start_idx += 1
start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))

start_idx = 0
seq_start_loc = [start_idx]
for seq_len in seq_lens:
start_idx += seq_len
seq_start_loc.append(start_idx)
assert torch.allclose(
attn_metadata.seq_start_loc,
torch.tensor(seq_start_loc, dtype=torch.int32, device=device))

assert torch.allclose(
attn_metadata.context_lens_tensor,
torch.tensor(context_lens, dtype=torch.int, device=device))
assert attn_metadata.max_decode_seq_len == max(seq_lens)
assert torch.allclose(
attn_metadata.seq_lens_tensor[:len(seq_lens)],
torch.tensor(seq_lens, dtype=torch.int, device=device))
@@ -185,23 +222,23 @@ def test_prepare_decode_cuda_graph(batch_size):
# It is padded up to
assert attn_metadata.block_tables.shape[1] == (
model_runner.get_max_block_per_batch())
# Cuda graph should not be used for prerill.
assert attn_metadata.use_cuda_graph is True

assert len(input_tokens) == expected_bs
assert len(input_positions) == expected_bs
assert input_tokens == input_positions
torch.allclose(input_tokens, input_positions)

# Verify Sampling
expected_selected_token_indices = []
selected_token_start_idx = 0
for seq_len in seq_lens:
for _ in context_lens:
expected_selected_token_indices.append(selected_token_start_idx)
selected_token_start_idx += 1
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
query_lens=seq_lens,
# query lens is all 1 for decode.
query_lens=[1 for _ in range(len(context_lens))],
device=model_runner.device,
pin_memory=model_runner.pin_memory)
actual = sampling_metadata.selected_token_indices
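The decode test above relies on CUDA-graph padding. A small sketch of that behavior, assuming a hypothetical helper (pad_seq_lens_for_cuda_graph is made up for illustration, while _get_graph_batch_size is the helper the test imports):

# Sketch: decode batches are padded up to the captured graph batch size,
# and each padded slot is given a dummy sequence length of 1.
def pad_seq_lens_for_cuda_graph(seq_lens, graph_batch_size):
    padded = list(seq_lens)
    while len(padded) < graph_batch_size:
        padded.append(1)
    return padded

# e.g. three real decode requests padded to a graph batch size of 4 (illustrative).
assert pad_seq_lens_for_cuda_graph([5, 7, 2], 4) == [5, 7, 2, 1]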
@@ -220,15 +257,27 @@ def test_empty_seq_group():
enforce_eager=False,
)
seq_group_metadata_list = []
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0

(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, slot_mapping,
return_seq_lens) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
model_input.seq_lens,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
@@ -285,9 +334,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# Add decode requests
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(seq_len))
context_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(context_len))
seq_data = SequenceData(prompt_toks)
seq_data.append_token_id(1, 0)
seq_data.update_num_computed_tokens(context_len)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
@@ -308,23 +359,17 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
assert len(attn_metadata.slot_mapping) == len(input_tokens)
assert len(input_positions) == len(input_tokens)
assert attn_metadata.num_prefills == prefill_batch_size
if enforce_eager:
assert attn_metadata.num_decode_tokens == decode_batch_size
else:
assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
decode_batch_size)
assert attn_metadata.num_decode_tokens == decode_batch_size
assert attn_metadata.num_prefill_tokens == sum(seq_lens)

# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
prefill_meta = model_runner._prepare_prompt(
prefill_metadata_list).attn_metadata
decode_meta = model_runner._prepare_decode(
decode_metadata_list).attn_metadata
attn_metadata = model_runner._prepare_model_input(
seq_group_metadata_list).attn_metadata

for attr_expected, attr_actual in zip(vars(prefill_meta),
for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
vars(prefill_meta_actual)):
assert attr_expected[1] == attr_actual[1]
for attr_expected, attr_actual in zip(vars(decode_meta),
for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata),
vars(decode_meta_actual)):
assert attr_expected[1] == attr_actual[1]
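The hybrid-batch test compares the sliced views of the combined metadata against per-stage metadata field by field. A condensed sketch of the same idea, restricted to the scalar counters and assuming the batch contains both prefill and decode requests (otherwise either property returns None):

# Sketch: the sliced views agree with the combined metadata's counters.
attn_metadata = model_runner._prepare_model_input(
    seq_group_metadata_list).attn_metadata
prefill_meta = attn_metadata.prefill_metadata
decode_meta = attn_metadata.decode_metadata

assert prefill_meta.num_prefills == attn_metadata.num_prefills
assert prefill_meta.num_decode_tokens == 0
assert decode_meta.num_prefills == 0
assert decode_meta.num_decode_tokens == attn_metadata.num_decode_tokens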
@@ -1,6 +1,5 @@
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataPerStage)
AttentionMetadata)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

@@ -8,6 +7,6 @@ __all__ = [
"Attention",
"AttentionBackend",
"AttentionMetadata",
"AttentionMetadataPerStage",
"Attention",
"get_attn_backend",
]
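With AttentionMetadataPerStage removed from __all__, the public import surface is just the combined class; a trivial usage sketch:

from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
                            get_attn_backend)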
@@ -21,7 +21,7 @@ class AttentionBackend(ABC):

@staticmethod
@abstractmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage":
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
raise NotImplementedError

@staticmethod
@@ -53,8 +53,34 @@ class AttentionBackend(ABC):


@dataclass
class AttentionMetadataPerStage:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
class AttentionMetadata:
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor

@property
@abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
"""Return the attention metadata that's required to run prefill
attention."""
pass

@property
@abstractmethod
def decode_metadata(self) -> Optional["AttentionMetadata"]:
"""Return the attention metadata that's required to run decode
attention."""
pass

def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
@@ -70,40 +96,10 @@ class AttentionMetadataPerStage:
}


T = TypeVar("T", bound=AttentionMetadataPerStage)
T = TypeVar("T", bound=AttentionMetadata)


@dataclass
class AttentionMetadata(Generic[T]):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# The attention metadata for prefill requests in a batch.
# None if there's no prefill requests in a batch.
prefill_metadata: Optional[T]
# The attention metadata for decode requests in a batch.
# None if there's no decode requests in a batch.
decode_metadata: Optional[T]
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor

def __post_init__(self):
if self.num_prefill_tokens > 0:
assert self.num_prefills > 0
assert self.prefill_metadata is not None
if self.num_decode_tokens > 0:
assert self.decode_metadata is not None


class AttentionImpl(ABC):
class AttentionImpl(ABC, Generic[T]):

@abstractmethod
def __init__(
@@ -125,7 +121,7 @@ class AttentionImpl(ABC):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_metadata: T,
kv_scale: float = 1.0,
) -> torch.Tensor:
raise NotImplementedError
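The abstract class above is what lets one metadata object drive both stages inside a backend's forward. A schematic sketch of that control flow (the FlashAttention diff below follows this shape; the function body here is deliberately elided and is not the actual implementation):

import torch

def forward_sketch(query, key, value, kv_cache, attn_metadata):
    # Output buffer covering prefill tokens followed by decode tokens.
    output = torch.empty_like(query)

    prefill_meta = attn_metadata.prefill_metadata
    if prefill_meta is not None:
        # Run variable-length prefill attention over the first
        # num_prefill_tokens tokens.
        ...

    decode_meta = attn_metadata.decode_metadata
    if decode_meta is not None:
        # Run paged decode attention over the remaining num_decode_tokens
        # tokens.
        ...

    return output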
@ -11,8 +11,7 @@ import torch
|
||||
from vllm_flash_attn import flash_attn_varlen_func
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
AttentionMetadata)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
|
||||
@ -58,8 +57,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata(AttentionMetadataPerStage,
|
||||
PagedAttentionMetadata):
|
||||
class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
"""Metadata for FlashAttentionBackend.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
@ -67,9 +65,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# Currently, input sequences can only contain all prompts
|
||||
# or all decoding. True if all sequences are prompts.
|
||||
is_prompt: bool
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
@ -84,14 +79,18 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
|
||||
# |-------------------- seq_len ----------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum query length in the batch.
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int]
|
||||
# Maximum sequence length in the batch.
|
||||
max_seq_len: Optional[int]
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
subquery_start_loc: Optional[torch.Tensor]
|
||||
query_start_loc: Optional[torch.Tensor]
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
@ -105,6 +104,70 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
|
||||
_cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.query_start_loc is not None
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.block_tables is not None
|
||||
assert self.seq_start_loc is not None
|
||||
|
||||
self._cached_prefill_metadata = FlashAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
||||
seq_lens=self.seq_lens[:self.num_prefills],
|
||||
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
||||
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
|
||||
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
||||
block_tables=self.block_tables[:self.num_prefills],
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert self.block_tables is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
self._cached_decode_metadata = FlashAttentionMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||
max_query_len=None,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
query_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=self.block_tables[self.num_prefills:],
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class FlashAttentionImpl(AttentionImpl):
|
||||
"""
|
||||
@ -168,7 +231,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata[FlashAttentionMetadata],
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
kv_scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
@ -228,8 +291,8 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
v=value,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_seq_len,
|
||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
@ -249,7 +312,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
key_cache,
|
||||
value_cache,
|
||||
prefill_meta.block_tables,
|
||||
prefill_meta.subquery_start_loc,
|
||||
prefill_meta.query_start_loc,
|
||||
prefill_meta.seq_lens_tensor,
|
||||
prefill_meta.context_lens_tensor,
|
||||
prefill_meta.max_query_len,
|
||||
@ -264,7 +327,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
value_cache,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.seq_lens_tensor,
|
||||
decode_meta.max_seq_len,
|
||||
decode_meta.max_decode_seq_len,
|
||||
self.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
|
||||
@ -8,8 +8,7 @@ from vllm_flash_attn import flash_attn_varlen_func
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
AttentionMetadata)
|
||||
|
||||
|
||||
class FlashInferBackend(AttentionBackend):
|
||||
@ -56,9 +55,10 @@ class FlashInferBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashInferMetadata(AttentionMetadataPerStage):
|
||||
|
||||
is_prompt: bool
|
||||
class FlashInferMetadata(AttentionMetadata):
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
|
||||
use_cuda_graph: bool = False
|
||||
|
||||
@ -67,7 +67,6 @@ class FlashInferMetadata(AttentionMetadataPerStage):
|
||||
# Metadata for the prefill stage since we still
|
||||
# use flash attention for prefill.
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
max_seq_len: Optional[int] = None
|
||||
block_tables: Optional[torch.Tensor] = None
|
||||
|
||||
# Metadata for the decode stage
|
||||
@ -113,7 +112,8 @@ class FlashInferMetadata(AttentionMetadataPerStage):
|
||||
# When using flashinfer, we are also creating the FlashInferMetadata,
|
||||
# which will also call post_init by default, here we want to skip the
|
||||
# post_init if it's the prefill phase.
|
||||
if not self.is_prompt:
|
||||
if self.num_prefills == 0:
|
||||
assert self.num_decode_tokens > 0
|
||||
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
self.workspace_buffer, "NHD")
|
||||
self.decode_wrapper.begin_forward(
|
||||
@ -138,6 +138,24 @@ class FlashInferMetadata(AttentionMetadataPerStage):
|
||||
skip_fields.add('decode_wrapper')
|
||||
return super().asdict_zerocopy(skip_fields)
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
|
||||
# Currently chunked prefill is not supported
|
||||
if self.num_decode_tokens == 0:
|
||||
assert self.num_prefills > 0
|
||||
return self
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["FlashInferMetadata"]:
|
||||
# Currently chunked prefill is not supported
|
||||
if self.num_prefills > 0:
|
||||
assert self.num_decode_tokens == 0
|
||||
return None
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class FlashInferImpl(AttentionImpl):
|
||||
|
||||
@ -172,7 +190,7 @@ class FlashInferImpl(AttentionImpl):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata[FlashInferMetadata],
|
||||
attn_metadata: FlashInferMetadata,
|
||||
kv_scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
assert kv_scale == 1.0
|
||||
@ -208,8 +226,8 @@ class FlashInferImpl(AttentionImpl):
|
||||
v=value,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_seq_len,
|
||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
|
||||
@ -6,8 +6,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
AttentionMetadata)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm.logger import init_logger
|
||||
@ -56,8 +55,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
|
||||
PagedAttentionMetadata):
|
||||
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
"""Metadata for FlashAttentionBackend.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
@ -65,9 +63,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# Currently, input sequences can only contain all prompts
|
||||
# or all decoding. True if all sequences are prompts.
|
||||
is_prompt: bool
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
@ -82,14 +77,18 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
|
||||
# |-------------------- seq_len ----------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum query length in the batch.
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int]
|
||||
# Maximum sequence length in the batch.
|
||||
max_seq_len: Optional[int]
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
subquery_start_loc: Optional[torch.Tensor]
|
||||
query_start_loc: Optional[torch.Tensor]
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
@ -102,6 +101,69 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.query_start_loc is not None
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.block_tables is not None
|
||||
assert self.seq_start_loc is not None
|
||||
|
||||
self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
||||
seq_lens=self.seq_lens[:self.num_prefills],
|
||||
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
||||
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
|
||||
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
||||
block_tables=self.block_tables[:self.num_prefills],
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert self.block_tables is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
self._cached_decode_metadata = ROCmFlashAttentionMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||
max_query_len=None,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
query_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=self.block_tables[self.num_prefills:],
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
@ -198,7 +260,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
|
||||
attn_metadata: ROCmFlashAttentionMetadata,
|
||||
kv_scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
@ -266,8 +328,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
None,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.max_seq_len,
|
||||
prefill_meta.max_seq_len,
|
||||
prefill_meta.max_prefill_seq_len,
|
||||
prefill_meta.max_prefill_seq_len,
|
||||
True,
|
||||
self.scale,
|
||||
)
|
||||
@ -290,8 +352,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
v=value,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_seq_len,
|
||||
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)
|
||||
@ -308,7 +370,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
key_cache,
|
||||
value_cache,
|
||||
prefill_meta.block_tables,
|
||||
prefill_meta.subquery_start_loc,
|
||||
prefill_meta.query_start_loc,
|
||||
prefill_meta.seq_lens_tensor,
|
||||
prefill_meta.context_lens_tensor,
|
||||
prefill_meta.max_query_len,
|
||||
@ -324,7 +386,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
value_cache,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.seq_lens_tensor,
|
||||
decode_meta.max_seq_len,
|
||||
decode_meta.max_decode_seq_len,
|
||||
self.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
|
||||
@ -7,8 +7,7 @@ import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
AttentionMetadata)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
|
||||
@ -54,8 +53,7 @@ class TorchSDPABackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
|
||||
AttentionMetadataPerStage):
|
||||
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
"""Metadata for TorchSDPABackend.
|
||||
"""
|
||||
# Currently, input sequences can only contain all prompts
|
||||
@ -72,8 +70,26 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
|
||||
# will not appear in the __repr__ and __init__
|
||||
self.attn_bias: Optional[List[torch.Tensor]] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
|
||||
# Currently chunked prefill is not supported
|
||||
if self.num_decode_tokens == 0:
|
||||
assert self.num_prefills > 0
|
||||
return self
|
||||
|
||||
class TorchSDPABackendImpl(AttentionImpl):
|
||||
return None
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
|
||||
# Currently chunked prefill is not supported
|
||||
if self.num_prefills > 0:
|
||||
assert self.num_decode_tokens == 0
|
||||
return None
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -200,7 +216,7 @@ class TorchSDPABackendImpl(AttentionImpl):
|
||||
value_cache,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.seq_lens_tensor,
|
||||
attn_metadata.max_seq_len,
|
||||
attn_metadata.max_decode_seq_len,
|
||||
self.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
|
||||
@ -9,8 +9,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
|
||||
LowerTriangularMaskWithTensorBias)
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
AttentionMetadata)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm.logger import init_logger
|
||||
@ -59,7 +58,7 @@ class XFormersBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
|
||||
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
"""Metadata for XFormersbackend.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
@ -67,9 +66,6 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
|
||||
dynamically, it should be stored in tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# Currently, input sequences can only contain all prompts
|
||||
# or all decoding. True if all sequences are prompts.
|
||||
is_prompt: bool
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens None if it is a decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
@ -83,15 +79,19 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
|
||||
# |-------------------- seq_len ----------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum query length in the batch.
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int]
|
||||
# FIXME: It is for flash attn.
|
||||
# Maximum sequence length in the batch.
|
||||
max_seq_len: Optional[int]
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
subquery_start_loc: Optional[torch.Tensor]
|
||||
query_start_loc: Optional[torch.Tensor]
|
||||
# FIXME: It is for flash attn.
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
@ -105,6 +105,8 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
use_cuda_graph: bool
|
||||
_cached_prefill_metadata: Optional["XFormersMetadata"] = None
|
||||
_cached_decode_metadata: Optional["XFormersMetadata"] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Set during the execution of the first attention op.
|
||||
@ -114,8 +116,68 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
|
||||
# will not appear in the __repr__ and __init__
|
||||
self.attn_bias: Optional[List[AttentionBias]] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["XFormersMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
class XFormersImpl(AttentionImpl):
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.query_start_loc is not None
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.block_tables is not None
|
||||
|
||||
self._cached_prefill_metadata = XFormersMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
||||
seq_lens=self.seq_lens[:self.num_prefills],
|
||||
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
||||
block_tables=self.block_tables[:self.num_prefills],
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["XFormersMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert self.block_tables is not None
|
||||
assert self.seq_lens_tensor is not None
|
||||
|
||||
self._cached_decode_metadata = XFormersMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||
max_query_len=None,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
query_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=self.block_tables[self.num_prefills:],
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||
"""
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
|<--------------- num_prefill_tokens ----------------->|
|
||||
@ -176,7 +238,7 @@ class XFormersImpl(AttentionImpl):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata[XFormersMetadata],
|
||||
attn_metadata: "XFormersMetadata",
|
||||
kv_scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with xFormers and PagedAttention.
|
||||
@ -244,7 +306,7 @@ class XFormersImpl(AttentionImpl):
|
||||
key_cache,
|
||||
value_cache,
|
||||
prefill_meta.block_tables,
|
||||
prefill_meta.subquery_start_loc,
|
||||
prefill_meta.query_start_loc,
|
||||
prefill_meta.seq_lens_tensor,
|
||||
prefill_meta.context_lens_tensor,
|
||||
prefill_meta.max_query_len,
|
||||
@ -261,7 +323,7 @@ class XFormersImpl(AttentionImpl):
|
||||
value_cache,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.seq_lens_tensor,
|
||||
decode_meta.max_seq_len,
|
||||
decode_meta.max_decode_seq_len,
|
||||
self.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
|
||||
@ -4,8 +4,7 @@ from typing import List, Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
from vllm.config import CacheConfig
|
||||
|
||||
@ -57,7 +56,7 @@ class Attention(nn.Module):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
|
||||
attn_metadata: AttentionMetadata,
|
||||
kv_scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
|
||||
|
||||
@ -16,8 +16,8 @@ class PagedAttentionMetadata:
|
||||
# (batch_size,). The length of sequences (entire tokens seen so far) per
|
||||
# sequence.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
# Maximum sequence length in the batch.
|
||||
max_seq_len: Optional[int]
|
||||
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
@ -166,7 +166,7 @@ class PagedAttention:
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
subquery_start_loc: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
seq_lens_tensor: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
max_query_len: int,
|
||||
@ -182,8 +182,8 @@ class PagedAttention:
|
||||
key_cache,
|
||||
value_cache,
|
||||
block_tables,
|
||||
# subquery_start_loc is (batch_size + 1,)
|
||||
subquery_start_loc[:-1],
|
||||
# query_start_loc is (batch_size + 1,)
|
||||
query_start_loc[:-1],
|
||||
seq_lens_tensor,
|
||||
context_lens,
|
||||
max_query_len,
|
||||
|
||||
@ -618,6 +618,11 @@ class EngineArgs:
|
||||
decoding_config = DecodingConfig(
|
||||
guided_decoding_backend=self.guided_decoding_backend)
|
||||
|
||||
if (model_config.get_sliding_window() is not None
|
||||
and scheduler_config.chunked_prefill_enabled):
|
||||
raise ValueError(
|
||||
"Chunked prefill is not supported with sliding window.")
|
||||
|
||||
return EngineConfig(model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
parallel_config=parallel_config,
|
||||
|
||||
@ -122,6 +122,7 @@ class RejectionSampler(nn.Module):
|
||||
draft_token_ids,
|
||||
bonus_token_ids,
|
||||
)
|
||||
|
||||
return output_token_ids
|
||||
|
||||
def _batch_modified_rejection_sampling(
|
||||
|
||||
@ -654,8 +654,9 @@ class SequenceGroupMetadata:
|
||||
return self.lora_request.lora_int_id if self.lora_request else 0
|
||||
|
||||
@property
|
||||
def token_chunk_size(self) -> Optional[int]:
|
||||
def token_chunk_size(self) -> int:
|
||||
"""Return the number of tokens to be processed (chunk size)."""
|
||||
assert self._token_chunk_size is not None
|
||||
return self._token_chunk_size
|
||||
|
||||
|
||||
|
||||
@ -293,21 +293,30 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
||||
prompt_token_ids = seq_data.get_prompt_token_ids()
|
||||
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
|
||||
|
||||
new_seq_data_dict = {
|
||||
target_seq_id:
|
||||
SequenceData(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
output_token_ids=new_output_token_ids,
|
||||
),
|
||||
}
|
||||
# This is a hack. Technically, spec decoding should compute
|
||||
# num_lookahead slots at one shot, but instead, it expands the batch
|
||||
# and evaluate one by one right now. context_len is seq_len - 1 because
|
||||
# the kv cache is filled by a previous batch in the batch expansion.
|
||||
for data in new_seq_data_dict.values():
|
||||
data.update_num_computed_tokens(data.get_len() - 1)
|
||||
|
||||
return SequenceGroupMetadata(
|
||||
request_id=seq_group_metadata.request_id,
|
||||
is_prompt=seq_group_metadata.is_prompt,
|
||||
seq_data={
|
||||
target_seq_id:
|
||||
SequenceData(
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
output_token_ids=new_output_token_ids,
|
||||
),
|
||||
},
|
||||
seq_data=new_seq_data_dict,
|
||||
sampling_params=seq_group_metadata.sampling_params,
|
||||
block_tables={
|
||||
target_seq_id: seq_group_metadata.block_tables[seq_id],
|
||||
},
|
||||
lora_request=None,
|
||||
token_chunk_size=1,
|
||||
)
|
||||
|
||||
def _split_scoring_output(
|
||||
|
||||
@ -114,6 +114,7 @@ class MultiStepWorker(Worker):
|
||||
token_logprob = seq_output.logprobs[token_id]
|
||||
|
||||
seq.append_token_id(token_id, token_logprob.logprob)
|
||||
seq.update_num_computed_tokens(1)
|
||||
|
||||
def _shallow_copy_inputs(
|
||||
self, seq_group_metadata_list: List[SequenceGroupMetadata]
|
||||
|
||||
@ -159,12 +159,10 @@ class CPUModelRunner:
|
||||
is_prompt=True,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=None,
|
||||
max_seq_len=None,
|
||||
max_decode_seq_len=None,
|
||||
num_prefills=len(seq_lens),
|
||||
num_prefill_tokens=num_prompt_tokens,
|
||||
num_decode_tokens=0,
|
||||
prefill_metadata=None,
|
||||
decode_metadata=None,
|
||||
block_tables=torch.tensor([]),
|
||||
slot_mapping=slot_mapping,
|
||||
)
|
||||
@ -213,7 +211,7 @@ class CPUModelRunner:
|
||||
block_table = block_table[-sliding_window_blocks:]
|
||||
block_tables.append(block_table)
|
||||
|
||||
max_seq_len = max(seq_lens)
|
||||
max_decode_seq_len = max(seq_lens)
|
||||
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
@ -243,12 +241,10 @@ class CPUModelRunner:
|
||||
slot_mapping=slot_mapping,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_seq_len=max_seq_len,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=len(input_tokens),
|
||||
num_prefills=0,
|
||||
prefill_metadata=None,
|
||||
decode_metadata=None,
|
||||
block_tables=block_tables,
|
||||
)
|
||||
return (
|
||||
|
||||
@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.model_runner import BatchType, ModelRunner
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -88,85 +88,24 @@ class EmbeddingModelRunner(ModelRunner):
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
|
||||
Set[LoRARequest], LoRAMapping, torch.Tensor]:
|
||||
if self.is_driver_worker:
|
||||
prefill_reqs = []
|
||||
decode_reqs = []
|
||||
for seq_group_meta in seq_group_metadata_list:
|
||||
if seq_group_meta.is_prompt:
|
||||
prefill_reqs.append(seq_group_meta)
|
||||
else:
|
||||
decode_reqs.append(seq_group_meta)
|
||||
|
||||
# Prepare input tensors.
|
||||
(
|
||||
input_tokens,
|
||||
input_positions,
|
||||
prefill_attn_metadata,
|
||||
prompt_lens,
|
||||
subquery_lens,
|
||||
lora_index_mapping,
|
||||
lora_prompt_mapping,
|
||||
attn_metadata,
|
||||
seq_lens,
|
||||
_,
|
||||
lora_mapping,
|
||||
lora_requests,
|
||||
multi_modal_input,
|
||||
slot_mapping,
|
||||
) = self._prepare_prompt(prefill_reqs)
|
||||
(
|
||||
decode_input_tokens,
|
||||
decode_input_positions,
|
||||
decode_attn_metadata,
|
||||
decode_lora_index_mapping,
|
||||
decode_lora_prompt_mapping,
|
||||
decode_lora_requests,
|
||||
decode_slot_mapping,
|
||||
) = self._prepare_decode(decode_reqs)
|
||||
|
||||
num_prefill_tokens,
|
||||
num_decode_tokens,
|
||||
num_prefills,
|
||||
) = self._prepare_model_input(seq_group_metadata_list)
|
||||
# Prepare PoolingMetadata
|
||||
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
|
||||
prompt_lens)
|
||||
|
||||
if not self.scheduler_config.chunked_prefill_enabled:
|
||||
assert (len(prefill_reqs) and len(decode_reqs)) == 0
|
||||
|
||||
num_prefills = len(prompt_lens)
|
||||
num_prefill_tokens = len(input_tokens)
|
||||
num_decode_tokens = len(decode_input_tokens)
|
||||
|
||||
# Coalesce tensors. Note that attn_metadata is currently not
|
||||
# coalesced for simplicity.
|
||||
input_tokens.extend(decode_input_tokens)
|
||||
input_positions.extend(decode_input_positions)
|
||||
slot_mapping.extend(decode_slot_mapping)
|
||||
lora_index_mapping.extend(decode_lora_index_mapping)
|
||||
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
|
||||
lora_requests.update(decode_lora_requests)
|
||||
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
input_positions = torch.tensor(input_positions,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
|
||||
if self.lora_config:
|
||||
lora_mapping = LoRAMapping(
|
||||
lora_index_mapping,
|
||||
lora_prompt_mapping,
|
||||
)
|
||||
else:
|
||||
lora_mapping = None
|
||||
|
||||
# Broadcast the metadata.
|
||||
# If batch contains both prefill and decode, it sends 2 broadcasts.
|
||||
# If it only contains 1 type, it triggers a single broadcast.
|
||||
if (prefill_attn_metadata is not None
|
||||
and decode_attn_metadata is not None):
|
||||
batch_type = BatchType.MIXED
|
||||
elif prefill_attn_metadata is not None:
|
||||
batch_type = BatchType.PREFILL
|
||||
else:
|
||||
batch_type = BatchType.DECODE
|
||||
seq_lens)
|
||||
|
||||
metadata_dict = {
|
||||
"input_tokens": input_tokens,
|
||||
@ -178,65 +117,26 @@ class EmbeddingModelRunner(ModelRunner):
|
||||
"num_decode_tokens": num_decode_tokens,
|
||||
"slot_mapping": slot_mapping,
|
||||
"num_prefills": num_prefills,
|
||||
"batch_type": batch_type,
|
||||
}
|
||||
if prefill_attn_metadata is not None:
|
||||
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
|
||||
else:
|
||||
assert decode_attn_metadata is not None
|
||||
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
|
||||
if attn_metadata:
|
||||
metadata_dict.update(attn_metadata.asdict_zerocopy())
|
||||
broadcast_tensor_dict(metadata_dict, src=0)
|
||||
|
||||
# Broadcast decode attn metadata for mixed batch type.
|
||||
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
|
||||
# We can potentially reduce the overhead by coelescing tensors.
|
||||
if batch_type == BatchType.MIXED:
|
||||
assert decode_attn_metadata is not None
|
||||
metadata_dict = decode_attn_metadata.asdict_zerocopy()
|
||||
broadcast_tensor_dict(metadata_dict, src=0)
|
||||
else:
|
||||
metadata_dict = broadcast_tensor_dict(src=0)
|
||||
input_tokens = metadata_dict.pop("input_tokens")
|
||||
input_positions = metadata_dict.pop("input_positions")
|
||||
slot_mapping = metadata_dict.pop("slot_mapping")
|
||||
num_prefills = metadata_dict.pop("num_prefills")
|
||||
lora_mapping = metadata_dict.pop("lora_mapping")
|
||||
lora_requests = metadata_dict.pop("lora_requests")
|
||||
multi_modal_input = metadata_dict.pop("multi_modal_input")
|
||||
num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
|
||||
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
|
||||
batch_type = metadata_dict.pop("batch_type")
|
||||
|
||||
# Create an attention metadata.
|
||||
prefill_attn_metadata = None
|
||||
decode_attn_metadata = None
|
||||
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
|
||||
prefill_attn_metadata = self.attn_backend.make_metadata(
|
||||
if metadata_dict:
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
**metadata_dict)
|
||||
else:
|
||||
decode_attn_metadata = self.attn_backend.make_metadata(
|
||||
**metadata_dict)
|
||||
|
||||
attn_metadata = None
|
||||
pooling_metadata = PoolingMetadata(seq_groups=None,
|
||||
seq_data=None,
|
||||
prompt_lens=None)
|
||||
|
||||
# if it is a mixed batch, decode attn_metadata is broadcasted
|
||||
# separately.
|
||||
if batch_type == BatchType.MIXED:
|
||||
metadata_dict = broadcast_tensor_dict(src=0)
|
||||
decode_attn_metadata = self.attn_backend.make_metadata(
|
||||
**metadata_dict)
|
||||
|
||||
attn_metadata = AttentionMetadata(
|
||||
num_prefills=num_prefills,
|
||||
slot_mapping=slot_mapping,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
prefill_metadata=prefill_attn_metadata,
|
||||
decode_metadata=decode_attn_metadata,
|
||||
)
|
||||
|
||||
return (input_tokens, input_positions, attn_metadata, pooling_metadata,
|
||||
lora_requests, lora_mapping, multi_modal_input)
|
||||
|
||||
|
||||
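A condensed sketch of the single-broadcast flow the embedding model runner now uses (attribute and function names follow the diff above; the surrounding worker state, and the fact that this runs inside a method with self, are assumed):

if self.is_driver_worker:
    metadata_dict = {
        "input_tokens": input_tokens,
        "slot_mapping": slot_mapping,
    }
    if attn_metadata:
        # One flat dict carries the combined metadata's fields.
        metadata_dict.update(attn_metadata.asdict_zerocopy())
    broadcast_tensor_dict(metadata_dict, src=0)
else:
    metadata_dict = broadcast_tensor_dict(src=0)
    input_tokens = metadata_dict.pop("input_tokens")
    slot_mapping = metadata_dict.pop("slot_mapping")
    # Remaining entries are backend-specific fields; rebuild the combined
    # metadata, or leave it as None for an empty batch.
    attn_metadata = (self.attn_backend.make_metadata(**metadata_dict)
                     if metadata_dict else None)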
@ -1,13 +1,11 @@
|
||||
import time
|
||||
from enum import IntEnum
|
||||
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
|
||||
get_attn_backend)
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
VisionLanguageConfig)
|
||||
@ -37,66 +35,38 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
|
||||
]
|
||||
|
||||
|
||||
class PreparePromptMetadata(NamedTuple):
|
||||
input_tokens: List[int]
|
||||
input_positions: List[int]
|
||||
attn_metadata: Optional[AttentionMetadataPerStage]
|
||||
class ModelInput(NamedTuple):
|
||||
input_tokens: torch.Tensor
|
||||
input_positions: torch.Tensor
|
||||
attn_metadata: Optional[AttentionMetadata]
|
||||
seq_lens: List[int]
|
||||
query_lens: List[int]
|
||||
lora_index_mapping: List[int]
|
||||
lora_prompt_mapping: List[int]
|
||||
lora_mapping: Optional[LoRAMapping]
|
||||
lora_requests: Set[LoRARequest]
|
||||
multi_modal_input: Optional[torch.Tensor]
|
||||
slot_mapping: List[int]
|
||||
slot_mapping: torch.Tensor
|
||||
num_prefill_tokens: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
|
||||
@classmethod
|
||||
def empty(cls):
|
||||
return PreparePromptMetadata(
|
||||
input_tokens=[],
|
||||
input_positions=[],
|
||||
def empty(cls, device):
|
||||
return ModelInput(
|
||||
input_tokens=torch.empty(0, device=device),
|
||||
input_positions=torch.empty(0, device=device),
|
||||
attn_metadata=None,
|
||||
seq_lens=[],
|
||||
query_lens=[],
|
||||
lora_index_mapping=[],
|
||||
lora_prompt_mapping=[],
|
||||
lora_mapping=None,
|
||||
lora_requests=set(),
|
||||
multi_modal_input=None,
|
||||
slot_mapping=[],
|
||||
slot_mapping=torch.empty(0, device=device),
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=0,
|
||||
num_prefills=0,
|
||||
)


class PrepareDecodeMetadata(NamedTuple):
input_tokens: List[int]
input_positions: List[int]
attn_metadata: Optional[AttentionMetadata]
lora_index_mapping: List[int]
lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest]
slot_mapping: List[int]

@classmethod
def empty(cls):
return PrepareDecodeMetadata(
input_tokens=[],
input_positions=[],
attn_metadata=None,
lora_index_mapping=[],
lora_prompt_mapping=[],
lora_requests=set(),
slot_mapping=[],
)


# How batches are constructed.
class BatchType(IntEnum):
# Every batch is prefill.
PREFILL = 0
# Every batch is decode.
DECODE = 1
# Batch is a mixture of prefill and decode.
MIXED = 2
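A small helper, not part of the diff, showing how the enum values map onto the prefill/decode counts that the runner tracks:

    def classify_batch(num_prefills: int, num_decode_tokens: int) -> BatchType:
        # Mixed batches only occur when chunked prefill schedules both stages.
        if num_prefills > 0 and num_decode_tokens > 0:
            return BatchType.MIXED
        return BatchType.PREFILL if num_prefills > 0 else BatchType.DECODE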


class ModelRunner:

def __init__(
@ -216,10 +186,22 @@ class ModelRunner:
block_size = self.block_size
return (self.max_seq_len_to_capture + block_size - 1) // block_size

def _prepare_prompt(
def _prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> PreparePromptMetadata:
) -> ModelInput:
"""Prepare the model input based on a given sequence group.

The API assumes seq_group_metadata_list is sorted by prefill -> decode.

The result tensors and data structure also batches input in prefill
-> decode order. For example,

- input_tokens[:num_prefill_tokens] contains prefill tokens.
- input_tokens[num_prefill_tokens:] contains decode tokens.

If cuda graph is required, this API automatically pads inputs.
"""
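To make the layout described in the docstring concrete, a sketch with made-up numbers (two prefills of lengths 3 and 2 followed by three decoding sequences):

    import torch

    input_tokens = torch.tensor([11, 12, 13, 21, 22, 7, 8, 9])
    num_prefill_tokens = 5
    prefill_tokens = input_tokens[:num_prefill_tokens]  # tensor([11, 12, 13, 21, 22])
    decode_tokens = input_tokens[num_prefill_tokens:]   # tensor([7, 8, 9])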
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
@ -228,212 +210,16 @@ class ModelRunner:
lora_requests: Set[LoRARequest] = set()

seq_lens: List[int] = []
prefill_seq_lens: List[int] = []
decode_seq_lens: List[int] = []
context_lens: List[int] = []
query_lens: List[int] = []
prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = []

if len(seq_group_metadata_list) == 0:
return PreparePromptMetadata.empty()

for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]

computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")

token_chunk_size = seq_group_metadata.token_chunk_size
seq_data = seq_group_metadata.seq_data[seq_id]
context_len = seq_data.get_num_computed_tokens()
# We should use get_len here because in case of preemption
# it contains output tokens.
seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
seq_lens.append(seq_len)

# NOTE: This only works for oooooooxxx style attention.
if computed_block_nums is not None and len(
computed_block_nums) > 0 and self.sliding_window is None:
# Prefix is not supported with sliding_window
context_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[context_len:]
prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
if seq_group_metadata.block_tables is not None:
# Prefill has chunked before.
block_table = seq_group_metadata.block_tables[seq_id]
prefix_block_tables.append(block_table)
else:
# The first prefill.
prefix_block_tables.append([])
else:
prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this
# assumption can be changed once chunked prefill is introduced.
assert context_len == 0

# actual prompt lens
context_lens.append(context_len)
query_lens.append(seq_len - context_len)

input_tokens.extend(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(list(range(context_len, seq_len)))
lora_id = seq_group_metadata.lora_int_id

if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)

lora_index_mapping += [lora_id] * (seq_len - context_len)
lora_prompt_mapping.extend([lora_id] * (
seq_len - context_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs else 1))

if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append(
seq_group_metadata.multi_modal_data.data)

if _is_block_tables_empty(seq_group_metadata.block_tables):
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue

# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]

# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
assert context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention")
start_idx = max(0, seq_len - self.sliding_window)

for i in range(context_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue

block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
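A worked example of the slot computation above, with illustrative values (block_size of 4, a block table of physical blocks [7, 2], a 6-token prompt, and no sliding window so start_idx stays 0):

    block_size = 4
    block_table = [7, 2]
    slot_mapping = []
    for i in range(6):
        block_number = block_table[i // block_size]
        slot_mapping.append(block_number * block_size + i % block_size)
    # slot_mapping == [28, 29, 30, 31, 8, 9]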

max_query_len = max(query_lens)
max_seq_len = max(seq_lens)
assert max_query_len > 0

context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)

if multi_modal_input_list:
assert self.vision_language_config, (
"Multi-modal inputs are only supported by "
"vision language models.")
multi_modal_input = torch.cat(multi_modal_input_list,
dim=0).to(self.device)
else:
multi_modal_input = None

# Prepare prefix block tables
max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
block_tables = make_tensor_with_pad(
prefix_block_tables,
max_len=max_prompt_block_table_len,
pad=0,
dtype=torch.int,
device=self.device,
)

# Query length can be shorter than key (i.e., prompt) when prefill
# is chunked or prefix cached.
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)

seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)

torch.cumsum(query_lens_tensor,
dim=0,
dtype=subquery_start_loc.dtype,
out=subquery_start_loc[1:])

torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
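The two cumulative sums simply turn per-sequence lengths into start offsets. With illustrative lengths query_lens=[3, 2]:

    import torch

    query_lens_tensor = torch.tensor([3, 2], dtype=torch.long)
    query_start_loc = torch.zeros(3, dtype=torch.int32)
    torch.cumsum(query_lens_tensor,
                 dim=0,
                 dtype=query_start_loc.dtype,
                 out=query_start_loc[1:])
    # query_start_loc == tensor([0, 3, 5]); seq_start_loc is built the same
    # way from seq_lens.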

if self.attn_backend.get_name() == "flashinfer":
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
use_cuda_graph=False,
seq_start_loc=seq_start_loc,
max_seq_len=max_seq_len,
block_tables=block_tables)
else:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
)

return PreparePromptMetadata(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
seq_lens=seq_lens,
query_lens=query_lens,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
multi_modal_input=multi_modal_input,
slot_mapping=slot_mapping,
)

def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> PrepareDecodeMetadata:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
multi_modal_input_list: List[torch.Tensor] = []
decode_only = True
num_prefills = 0
num_prefill_tokens = 0
num_decode_tokens = 0

# The following fields are only for flashinfer
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
@ -454,60 +240,186 @@ class ModelRunner:
paged_kv_last_page_len: List[int] = []

if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty()
return ModelInput.empty(self.device)

for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1

seq_ids = list(seq_group_metadata.seq_data.keys())
lora_id = seq_group_metadata.lora_int_id

if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
is_prompt = seq_group_metadata.is_prompt

for seq_id in seq_ids:
computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")

seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append(generation_token)
if is_prompt:
context_len = seq_data.get_num_computed_tokens()
else:
# get_num_computed_tokens is incorrect for spec decoding.
# So, we should have a special logic here.
# TODO(sang): Fix it.
context_len = seq_data.get_len() - 1

seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append(position)
seq_len = min(
seq_data.get_len(),
context_len + seq_group_metadata.token_chunk_size)
if is_prompt:
tokens = seq_data.get_token_ids()[context_len:seq_len]
else:
# Optimization. get_token_ids requires the entire copy of
# tokens.
tokens = [seq_data.get_last_token_id()]

seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
seq_lens.append(seq_len)
# Prefix cache was hit.
# Prefix is not supported with sliding_window
prefix_cache_hit = (computed_block_nums is not None
and len(computed_block_nums) > 0
and self.sliding_window is None
and is_prompt)

block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
lora_index_mapping.append(lora_id)
lora_prompt_mapping.append(lora_id)
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size
tokens = tokens[context_len:]
if self.attn_backend.get_name() == "flash-attn":
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
# TODO(woosuk): This is a temporary fix. We should
# provide a unified interface for different backends.
block_table = seq_group_metadata.block_tables[seq_id]
else:
block_table = computed_block_nums
elif (self.scheduler_config.chunked_prefill_enabled
or not is_prompt):
if seq_group_metadata.block_tables is not None:
# chunked prefill or decode
block_table = seq_group_metadata.block_tables[seq_id]
if self.sliding_window is not None:
# chunked prefill doesn't support sliding window.
assert (not self.scheduler_config.
chunked_prefill_enabled)
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]

if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]
if self.attn_backend.get_name() == "flashinfer":
paged_kv_indices.extend(block_table)
paged_kv_indptr.append(paged_kv_indptr[-1] +
len(block_table))
last_page_len = seq_data.get_len(
) % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
paged_kv_last_page_len.append(last_page_len)
else:
# Only happens when memory profiling runs.
block_table = []
else:
# Prefill without chunked prefill or memory profiling.
block_table = []
block_tables.append(block_table)

paged_kv_indices.extend(block_table)
paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
last_page_len = seq_data.get_len() % self.block_size
if last_page_len == 0:
last_page_len = self.block_size
paged_kv_last_page_len.append(last_page_len)
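The three flashinfer lists describe the paged KV layout. An illustrative layout with a block size of 16 and two sequences of lengths 20 and 7 that own blocks [0, 1] and [2]:

    paged_kv_indices = [0, 1, 2]      # all block ids, concatenated per sequence
    paged_kv_indptr = [0, 2, 3]       # sequence i owns indices[indptr[i]:indptr[i + 1]]
    paged_kv_last_page_len = [4, 7]   # 20 % 16 == 4, 7 % 16 == 7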
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle sliding window attn.
if (self.sliding_window is not None and not is_prompt):
seq_len = min(seq_len, self.sliding_window)
context_len = seq_len - 1

seq_lens.append(seq_len)
context_lens.append(context_len)
query_len = seq_len - context_len
query_lens.append(query_len)
input_tokens.extend(tokens)
input_positions.extend(list(range(context_len, seq_len)))
lora_id = seq_group_metadata.lora_int_id

if is_prompt:
assert len(seq_ids) == 1
num_prefills += 1
num_prefill_tokens += len(tokens)
decode_only = False
prefill_seq_lens.append(seq_len)
else:
assert query_len == 1, (
"seq_len: {}, context_len: {}, query_len: {}".format(
seq_len, context_len, query_len))
num_decode_tokens += query_len
decode_seq_lens.append(seq_len)

if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)

lora_index_mapping += [lora_id] * (seq_len - context_len)
lora_prompt_mapping.extend(
[lora_id] *
(seq_len -
context_len if seq_group_metadata.sampling_params
and seq_group_metadata.sampling_params.prompt_logprobs
else 1))

if seq_group_metadata.multi_modal_data:
multi_modal_input_list.append(
seq_group_metadata.multi_modal_data.data)

if _is_block_tables_empty(seq_group_metadata.block_tables):
# During memory profiling, the block tables are not
# initialized yet. In this case, we just use a dummy
# slot mapping.
# In embeddings, the block tables are {seq_id: None}.
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
continue

# Compute the slot mapping.
block_table = seq_group_metadata.block_tables[seq_id]

# Mask the [0, start_idx) tokens of the prompt with
# _PAD_SLOT_ID, where start_idx is max(0, seq_len -
# sliding_window). For example, if the prompt len is 10,
# sliding window is 8, and block size is 4, the first two
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
if is_prompt:
assert context_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention")
# It is an optimization. When it is decoding, it is always
# 0. When prefill, we use it to not write slots to kv cache
# to save memory.
start_idx = max(0, query_len - self.sliding_window)

for i in range(context_len, seq_len):
if i < start_idx:
slot_mapping.append(_PAD_SLOT_ID)
continue

block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)

# vLLM uses cuda graph only for decoding requests.
# See `capture_model` API for more details.
# For decoding requests, batch_size == input_tokens.
batch_size = len(input_tokens)
max_seq_len = max(seq_lens)
use_captured_graph = (not self.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_seq_len <= self.max_seq_len_to_capture)
max_query_len = max(query_lens)
max_prefill_seq_len = max(prefill_seq_lens, default=0)
max_decode_seq_len = max(decode_seq_lens, default=0)

# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
use_captured_graph = (
decode_only and not self.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.max_seq_len_to_capture)
if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
@ -519,18 +431,9 @@ class ModelRunner:
block_tables.append([])
lora_index_mapping.append(0)
batch_size = graph_batch_size
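A hedged sketch of the rounding that the padding relies on; the real _get_graph_batch_size lives elsewhere in this file, so the helper below is only an approximation of its contract (smallest captured size that fits the batch):

    def pad_to_captured_batch_size(batch_size: int) -> int:
        # _BATCH_SIZES_TO_CAPTURE is sorted ascending, so the first
        # captured size >= batch_size is the padding target.
        for captured in _BATCH_SIZES_TO_CAPTURE:
            if captured >= batch_size:
                return captured
        raise ValueError("batch size too large for CUDA graph capture")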

seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
num_decode_tokens = batch_size

if use_captured_graph:
# When using cuda-graph all these tensors should be
# padded.
assert seq_lens_tensor.shape[0] == len(input_tokens)
assert seq_lens_tensor.shape[0] == len(input_positions)
assert seq_lens_tensor.shape[0] == len(slot_mapping)

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.graph_block_tables[:batch_size]
@ -548,6 +451,57 @@ class ModelRunner:
dtype=torch.int,
device=self.device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))

context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)

if multi_modal_input_list:
assert self.vision_language_config, (
"Multi-modal inputs are only supported by "
"vision language models.")
multi_modal_input = torch.cat(multi_modal_input_list,
dim=0).to(self.device)
else:
multi_modal_input = None

seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
query_lens_tensor = torch.tensor(query_lens,
dtype=torch.long,
device=self.device)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)

seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)

torch.cumsum(query_lens_tensor,
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])

torch.cumsum(seq_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])

input_tokens_tensor = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions_tensor = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping_tensor = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)

if self.attn_backend.get_name() == "flashinfer":
if not hasattr(self, "flashinfer_workspace_buffer"):
@ -555,53 +509,75 @@ class ModelRunner:
# Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
self.flashinfer_workspace_buffer = torch.empty(
16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
paged_kv_indptr = torch.tensor(paged_kv_indptr,
dtype=torch.int,
device=self.device)
paged_kv_indices = torch.tensor(paged_kv_indices,
dtype=torch.int,
device=self.device)
paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len,
paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
dtype=torch.int,
device=self.device)
paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
dtype=torch.int,
device=self.device)
paged_kv_last_page_len_tensor = torch.tensor(
paged_kv_last_page_len, dtype=torch.int, device=self.device)
kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
self.model_config.dtype)

attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
use_cuda_graph=False,
max_prefill_seq_len=max_prefill_seq_len,
block_tables=block_tables,
workspace_buffer=self.flashinfer_workspace_buffer,
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
paged_kv_indptr=paged_kv_indptr_tensor,
paged_kv_indices=paged_kv_indices_tensor,
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
num_qo_heads=self.model_config.get_num_attention_heads(
self.parallel_config),
num_kv_heads=self.model_config.get_num_kv_heads(
self.parallel_config),
head_dim=self.model_config.get_head_size(),
page_size=self.block_size,
page_size=16,
seq_start_loc=seq_start_loc,
data_type=kv_cache_dtype)
else:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
seq_lens=None,
num_prefills=num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=None,
max_seq_len=max_seq_len,
subquery_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
)
return PrepareDecodeMetadata(
input_tokens=input_tokens,
input_positions=input_positions,

if self.lora_config:
lora_mapping = LoRAMapping(
lora_index_mapping,
lora_prompt_mapping,
)
else:
lora_mapping = None

return ModelInput(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
attn_metadata=attn_metadata,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
seq_lens=seq_lens,
query_lens=query_lens,
lora_mapping=lora_mapping,
lora_requests=lora_requests,
slot_mapping=slot_mapping,
multi_modal_input=multi_modal_input,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
)
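Downstream, prepare_input_tensors consumes this NamedTuple; a condensed sketch of that hand-off (attribute access instead of tuple unpacking is an illustration, not the exact code):

    model_input = self._prepare_model_input(seq_group_metadata_list)
    sampling_metadata = SamplingMetadata.prepare(
        seq_group_metadata_list, model_input.seq_lens,
        model_input.query_lens, self.device, self.pin_memory)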

def prepare_input_tensors(
@ -610,85 +586,25 @@ class ModelRunner:
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[LoRARequest], LoRAMapping, torch.Tensor]:
if self.is_driver_worker:
prefill_reqs = []
decode_reqs = []
for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
prefill_reqs.append(seq_group_meta)
else:
decode_reqs.append(seq_group_meta)

# Prepare input tensors.
(
input_tokens,
input_positions,
prefill_attn_metadata,
attn_metadata,
seq_lens,
query_lens,
lora_index_mapping,
lora_prompt_mapping,
lora_mapping,
lora_requests,
multi_modal_input,
slot_mapping,
) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
decode_input_positions,
decode_attn_metadata,
decode_lora_index_mapping,
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
num_prefill_tokens,
num_decode_tokens,
num_prefills,
) = self._prepare_model_input(seq_group_metadata_list)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.pin_memory)

if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0

num_prefills = len(seq_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)

# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens.extend(decode_input_tokens)
input_positions.extend(decode_input_positions)
slot_mapping.extend(decode_slot_mapping)
lora_index_mapping.extend(decode_lora_index_mapping)
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
lora_requests.update(decode_lora_requests)

input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)

if self.lora_config:
lora_mapping = LoRAMapping(
lora_index_mapping,
lora_prompt_mapping,
)
else:
lora_mapping = None

# Broadcast the metadata.
# If batch contains both prefill and decode, it sends 2 broadcasts.
# If it only contains 1 type, it triggers a single broadcast.
if (prefill_attn_metadata is not None
and decode_attn_metadata is not None):
batch_type = BatchType.MIXED
elif prefill_attn_metadata is not None:
batch_type = BatchType.PREFILL
else:
batch_type = BatchType.DECODE

metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
@ -701,46 +617,24 @@ class ModelRunner:
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
"batch_type": batch_type,
}
if prefill_attn_metadata is not None:
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
else:
assert decode_attn_metadata is not None
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
if attn_metadata:
metadata_dict.update(attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
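The broadcast is two-sided: the driver fills metadata_dict and sends it, while the other workers rebuild the metadata from what they receive. A condensed sketch of both sides (names follow the surrounding code, control flow simplified):

    if self.is_driver_worker:
        broadcast_tensor_dict(metadata_dict, src=0)   # driver sends
    else:
        metadata_dict = broadcast_tensor_dict(src=0)  # workers receive
        if metadata_dict:
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)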

# Broadcast decode attn metadata for mixed batch type.
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
# We can potentially reduce the overhead by coalescing tensors.
if batch_type == BatchType.MIXED:
assert decode_attn_metadata is not None
metadata_dict = decode_attn_metadata.asdict_zerocopy()
broadcast_tensor_dict(metadata_dict, src=0)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions")
slot_mapping = metadata_dict.pop("slot_mapping")
num_prefills = metadata_dict.pop("num_prefills")
selected_token_indices = metadata_dict.pop(
"selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input")
num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
batch_type = metadata_dict.pop("batch_type")

# Create an attention metadata.
prefill_attn_metadata = None
decode_attn_metadata = None
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
prefill_attn_metadata = self.attn_backend.make_metadata(
if metadata_dict:
attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
else:
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
attn_metadata = None
sampling_metadata = SamplingMetadata(
seq_groups=None,
selected_token_indices=selected_token_indices,
@ -748,22 +642,6 @@ class ModelRunner:
num_prompts=0,
)

# if it is a mixed batch, decode attn_metadata is broadcasted
# separately.
if batch_type == BatchType.MIXED:
metadata_dict = broadcast_tensor_dict(src=0)
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)

attn_metadata = AttentionMetadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata,
)

return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping,
multi_modal_input)
@ -954,25 +832,21 @@ class ModelRunner:
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
# Create dummy attn_metadata.
decode_metadata = self.attn_backend.make_metadata(
is_prompt=False,
seq_lens=None,
seq_lens_tensor=seq_lens[:batch_size],
max_query_len=None,
max_seq_len=self.max_seq_len_to_capture,
subquery_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)
attn_metadata = AttentionMetadata(
attn_metadata = self.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping[:batch_size],
prefill_metadata=None,
decode_metadata=decode_metadata,
seq_lens=None,
seq_lens_tensor=seq_lens[:batch_size],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)


if self.lora_config: