mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 10:09:08 +08:00
[Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode to a single API (#4681)
This PR combines prepare_prompt and prepare_decode into a single API. It also coalesces the attention metadata for prefill and decode into a single class and allows that metadata to be sliced when running the attention backend. Finally, it refactors subquery_start_loc, which was not refactored in the previous PR.
This commit is contained in:
parent
8a7cc254a0
commit
65bf2ac165
@ -58,19 +58,25 @@ def test_prepare_prompt(batch_size):
|
|||||||
expected_selected_token_indices.append(selected_token_start_idx +
|
expected_selected_token_indices.append(selected_token_start_idx +
|
||||||
seq_len - 1)
|
seq_len - 1)
|
||||||
selected_token_start_idx += seq_len
|
selected_token_start_idx += seq_len
|
||||||
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
|
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
|
||||||
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
|
input_tokens = model_input.input_tokens
|
||||||
|
input_positions = model_input.input_positions
|
||||||
|
attn_metadata = model_input.attn_metadata
|
||||||
|
return_seq_lens = model_input.seq_lens
|
||||||
|
slot_mapping = model_input.slot_mapping
|
||||||
assert return_seq_lens == seq_lens
|
assert return_seq_lens == seq_lens
|
||||||
assert len(slot_mapping) == len(input_tokens)
|
assert len(slot_mapping) == len(input_tokens)
|
||||||
|
|
||||||
# Verify input metadata is correct for prompts.
|
# Verify input metadata is correct for prompts.
|
||||||
device = model_runner.device
|
device = model_runner.device
|
||||||
assert attn_metadata.is_prompt is True
|
assert attn_metadata.num_prefills > 0
|
||||||
|
assert attn_metadata.num_decode_tokens == 0
|
||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
attn_metadata.seq_lens_tensor,
|
attn_metadata.seq_lens_tensor,
|
||||||
torch.tensor(seq_lens, device=device, dtype=torch.int))
|
torch.tensor(seq_lens, device=device, dtype=torch.int))
|
||||||
assert attn_metadata.seq_lens == seq_lens
|
assert attn_metadata.seq_lens == seq_lens
|
||||||
assert attn_metadata.max_seq_len == max(seq_lens)
|
assert attn_metadata.max_prefill_seq_len == max(seq_lens)
|
||||||
|
assert attn_metadata.max_decode_seq_len == 0
|
||||||
|
|
||||||
# Test subquery start locs.
|
# Test subquery start locs.
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
@ -79,11 +85,11 @@ def test_prepare_prompt(batch_size):
|
|||||||
start_idx += seq_len
|
start_idx += seq_len
|
||||||
start_loc.append(start_idx)
|
start_loc.append(start_idx)
|
||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
attn_metadata.subquery_start_loc,
|
attn_metadata.query_start_loc,
|
||||||
torch.tensor(start_loc, dtype=torch.int32, device=device))
|
torch.tensor(start_loc, dtype=torch.int32, device=device))
|
||||||
|
|
||||||
# Test seq start locs. Note that for normal prefill it is
|
# Test seq start locs. Note that for normal prefill it is
|
||||||
# equivalent to subquery_start_loc.
|
# equivalent to query_start_loc.
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
seq_start_loc = [start_idx]
|
seq_start_loc = [start_idx]
|
||||||
for seq_len in seq_lens:
|
for seq_len in seq_lens:
|
||||||
@ -123,7 +129,7 @@ def test_prepare_prompt(batch_size):
|
|||||||
device=actual.device,
|
device=actual.device,
|
||||||
dtype=actual.dtype)
|
dtype=actual.dtype)
|
||||||
torch.testing.assert_close(actual, expected)
|
torch.testing.assert_close(actual, expected)
|
||||||
assert input_tokens == input_positions
|
torch.allclose(input_tokens, input_positions)
|
||||||
|
|
||||||
actual = sampling_metadata.selected_token_indices
|
actual = sampling_metadata.selected_token_indices
|
||||||
expected = torch.tensor(expected_selected_token_indices,
|
expected = torch.tensor(expected_selected_token_indices,
|
||||||
@ -144,14 +150,18 @@ def test_prepare_decode_cuda_graph(batch_size):
|
|||||||
enable_chunked_prefill=False,
|
enable_chunked_prefill=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
seq_lens = []
|
context_lens = []
|
||||||
seq_group_metadata_list = []
|
seq_group_metadata_list = []
|
||||||
|
# Assume each seq group finishes prefill.
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
# make sure all tokens fit into one block
|
# make sure all tokens fit into one block
|
||||||
seq_len = i % (model_runner.block_size - 1) + 1
|
context_len = i % (model_runner.block_size - 1) + 1
|
||||||
seq_lens.append(seq_len)
|
context_lens.append(context_len)
|
||||||
seq_data = list(range(seq_len))
|
seq_data = list(range(context_len))
|
||||||
seq_data = SequenceData(seq_data)
|
seq_data = SequenceData(seq_data)
|
||||||
|
seq_data.update_num_computed_tokens(context_len)
|
||||||
|
# Append one token ID since prefill is finished.
|
||||||
|
seq_data.append_token_id(1, 0)
|
||||||
seq_group_metadata = SequenceGroupMetadata(
|
seq_group_metadata = SequenceGroupMetadata(
|
||||||
request_id=f"test_{i}",
|
request_id=f"test_{i}",
|
||||||
is_prompt=False,
|
is_prompt=False,
|
||||||
@ -162,18 +172,45 @@ def test_prepare_decode_cuda_graph(batch_size):
|
|||||||
assert seq_group_metadata.token_chunk_size == 1
|
assert seq_group_metadata.token_chunk_size == 1
|
||||||
seq_group_metadata_list.append(seq_group_metadata)
|
seq_group_metadata_list.append(seq_group_metadata)
|
||||||
|
|
||||||
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
|
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
|
||||||
model_runner._prepare_decode(seq_group_metadata_list))
|
input_tokens, input_positions, attn_metadata, slot_mapping = (
|
||||||
|
model_input.input_tokens, model_input.input_positions,
|
||||||
|
model_input.attn_metadata, model_input.slot_mapping)
|
||||||
assert len(slot_mapping) == len(input_tokens)
|
assert len(slot_mapping) == len(input_tokens)
|
||||||
|
|
||||||
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
|
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
|
||||||
# Verify input metadata is correct for prompts.
|
# Verify input metadata is correct for prompts.
|
||||||
device = model_runner.device
|
device = model_runner.device
|
||||||
assert attn_metadata.is_prompt is False
|
assert attn_metadata.num_prefills == 0
|
||||||
assert attn_metadata.seq_lens is None
|
assert attn_metadata.num_prefill_tokens == 0
|
||||||
assert attn_metadata.subquery_start_loc is None
|
seq_lens = [context_len + 1 for context_len in context_lens]
|
||||||
assert attn_metadata.seq_start_loc is None
|
# seq_lens are padded to expected_bs
|
||||||
assert attn_metadata.max_seq_len == max(seq_lens)
|
for _ in range(expected_bs - len(seq_lens)):
|
||||||
|
seq_lens.append(1)
|
||||||
|
assert attn_metadata.seq_lens == seq_lens
|
||||||
|
start_idx = 0
|
||||||
|
start_loc = [start_idx]
|
||||||
|
for _ in context_lens:
|
||||||
|
# decode has only 1 token for query.
|
||||||
|
start_idx += 1
|
||||||
|
start_loc.append(start_idx)
|
||||||
|
assert torch.allclose(
|
||||||
|
attn_metadata.query_start_loc,
|
||||||
|
torch.tensor(start_loc, dtype=torch.int32, device=device))
|
||||||
|
|
||||||
|
start_idx = 0
|
||||||
|
seq_start_loc = [start_idx]
|
||||||
|
for seq_len in seq_lens:
|
||||||
|
start_idx += seq_len
|
||||||
|
seq_start_loc.append(start_idx)
|
||||||
|
assert torch.allclose(
|
||||||
|
attn_metadata.seq_start_loc,
|
||||||
|
torch.tensor(seq_start_loc, dtype=torch.int32, device=device))
|
||||||
|
|
||||||
|
assert torch.allclose(
|
||||||
|
attn_metadata.context_lens_tensor,
|
||||||
|
torch.tensor(context_lens, dtype=torch.int, device=device))
|
||||||
|
assert attn_metadata.max_decode_seq_len == max(seq_lens)
|
||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
attn_metadata.seq_lens_tensor[:len(seq_lens)],
|
attn_metadata.seq_lens_tensor[:len(seq_lens)],
|
||||||
torch.tensor(seq_lens, dtype=torch.int, device=device))
|
torch.tensor(seq_lens, dtype=torch.int, device=device))
|
||||||
@ -185,23 +222,23 @@ def test_prepare_decode_cuda_graph(batch_size):
|
|||||||
# It is padded up to
|
# It is padded up to
|
||||||
assert attn_metadata.block_tables.shape[1] == (
|
assert attn_metadata.block_tables.shape[1] == (
|
||||||
model_runner.get_max_block_per_batch())
|
model_runner.get_max_block_per_batch())
|
||||||
# Cuda graph should not be used for prerill.
|
|
||||||
assert attn_metadata.use_cuda_graph is True
|
assert attn_metadata.use_cuda_graph is True
|
||||||
|
|
||||||
assert len(input_tokens) == expected_bs
|
assert len(input_tokens) == expected_bs
|
||||||
assert len(input_positions) == expected_bs
|
assert len(input_positions) == expected_bs
|
||||||
assert input_tokens == input_positions
|
torch.allclose(input_tokens, input_positions)
|
||||||
|
|
||||||
# Verify Sampling
|
# Verify Sampling
|
||||||
expected_selected_token_indices = []
|
expected_selected_token_indices = []
|
||||||
selected_token_start_idx = 0
|
selected_token_start_idx = 0
|
||||||
for seq_len in seq_lens:
|
for _ in context_lens:
|
||||||
expected_selected_token_indices.append(selected_token_start_idx)
|
expected_selected_token_indices.append(selected_token_start_idx)
|
||||||
selected_token_start_idx += 1
|
selected_token_start_idx += 1
|
||||||
sampling_metadata = SamplingMetadata.prepare(
|
sampling_metadata = SamplingMetadata.prepare(
|
||||||
seq_group_metadata_list,
|
seq_group_metadata_list,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
query_lens=seq_lens,
|
# query lens is all 1 for decode.
|
||||||
|
query_lens=[1 for _ in range(len(context_lens))],
|
||||||
device=model_runner.device,
|
device=model_runner.device,
|
||||||
pin_memory=model_runner.pin_memory)
|
pin_memory=model_runner.pin_memory)
|
||||||
actual = sampling_metadata.selected_token_indices
|
actual = sampling_metadata.selected_token_indices
|
||||||
@ -220,15 +257,27 @@ def test_empty_seq_group():
|
|||||||
enforce_eager=False,
|
enforce_eager=False,
|
||||||
)
|
)
|
||||||
seq_group_metadata_list = []
|
seq_group_metadata_list = []
|
||||||
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
|
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
|
||||||
model_runner._prepare_decode(seq_group_metadata_list))
|
input_tokens, input_positions, attn_metadata, slot_mapping = (
|
||||||
|
model_input.input_tokens,
|
||||||
|
model_input.input_positions,
|
||||||
|
model_input.attn_metadata,
|
||||||
|
model_input.slot_mapping,
|
||||||
|
)
|
||||||
assert len(input_tokens) == 0
|
assert len(input_tokens) == 0
|
||||||
assert len(input_positions) == 0
|
assert len(input_positions) == 0
|
||||||
assert attn_metadata is None
|
assert attn_metadata is None
|
||||||
assert len(slot_mapping) == 0
|
assert len(slot_mapping) == 0
|
||||||
|
|
||||||
(input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _,
|
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
|
||||||
_, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
|
(input_tokens, input_positions, attn_metadata, slot_mapping,
|
||||||
|
return_seq_lens) = (
|
||||||
|
model_input.input_tokens,
|
||||||
|
model_input.input_positions,
|
||||||
|
model_input.attn_metadata,
|
||||||
|
model_input.slot_mapping,
|
||||||
|
model_input.seq_lens,
|
||||||
|
)
|
||||||
assert len(input_tokens) == 0
|
assert len(input_tokens) == 0
|
||||||
assert len(input_positions) == 0
|
assert len(input_positions) == 0
|
||||||
assert attn_metadata is None
|
assert attn_metadata is None
|
||||||
@ -285,9 +334,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
|||||||
# Add decode requests
|
# Add decode requests
|
||||||
for i in range(prefill_batch_size, batch_size):
|
for i in range(prefill_batch_size, batch_size):
|
||||||
# make sure all tokens fit into one block
|
# make sure all tokens fit into one block
|
||||||
seq_len = i % (model_runner.block_size - 1) + 1
|
context_len = i % (model_runner.block_size - 1) + 1
|
||||||
prompt_toks = list(range(seq_len))
|
prompt_toks = list(range(context_len))
|
||||||
seq_data = SequenceData(prompt_toks)
|
seq_data = SequenceData(prompt_toks)
|
||||||
|
seq_data.append_token_id(1, 0)
|
||||||
|
seq_data.update_num_computed_tokens(context_len)
|
||||||
seq_group_metadata = SequenceGroupMetadata(
|
seq_group_metadata = SequenceGroupMetadata(
|
||||||
request_id=f"test_{i}",
|
request_id=f"test_{i}",
|
||||||
is_prompt=False,
|
is_prompt=False,
|
||||||
@ -308,23 +359,17 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
|||||||
assert len(attn_metadata.slot_mapping) == len(input_tokens)
|
assert len(attn_metadata.slot_mapping) == len(input_tokens)
|
||||||
assert len(input_positions) == len(input_tokens)
|
assert len(input_positions) == len(input_tokens)
|
||||||
assert attn_metadata.num_prefills == prefill_batch_size
|
assert attn_metadata.num_prefills == prefill_batch_size
|
||||||
if enforce_eager:
|
assert attn_metadata.num_decode_tokens == decode_batch_size
|
||||||
assert attn_metadata.num_decode_tokens == decode_batch_size
|
|
||||||
else:
|
|
||||||
assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
|
|
||||||
decode_batch_size)
|
|
||||||
assert attn_metadata.num_prefill_tokens == sum(seq_lens)
|
assert attn_metadata.num_prefill_tokens == sum(seq_lens)
|
||||||
|
|
||||||
# Verify attn metadata is consistent. We don't need to test individual
|
# Verify attn metadata is consistent. We don't need to test individual
|
||||||
# values here because they are tested above.
|
# values here because they are tested above.
|
||||||
prefill_meta = model_runner._prepare_prompt(
|
attn_metadata = model_runner._prepare_model_input(
|
||||||
prefill_metadata_list).attn_metadata
|
seq_group_metadata_list).attn_metadata
|
||||||
decode_meta = model_runner._prepare_decode(
|
|
||||||
decode_metadata_list).attn_metadata
|
|
||||||
|
|
||||||
for attr_expected, attr_actual in zip(vars(prefill_meta),
|
for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
|
||||||
vars(prefill_meta_actual)):
|
vars(prefill_meta_actual)):
|
||||||
assert attr_expected[1] == attr_actual[1]
|
assert attr_expected[1] == attr_actual[1]
|
||||||
for attr_expected, attr_actual in zip(vars(decode_meta),
|
for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata),
|
||||||
vars(decode_meta_actual)):
|
vars(decode_meta_actual)):
|
||||||
assert attr_expected[1] == attr_actual[1]
|
assert attr_expected[1] == attr_actual[1]
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
AttentionMetadata,
|
AttentionMetadata)
|
||||||
AttentionMetadataPerStage)
|
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
|
||||||
@ -8,6 +7,6 @@ __all__ = [
|
|||||||
"Attention",
|
"Attention",
|
||||||
"AttentionBackend",
|
"AttentionBackend",
|
||||||
"AttentionMetadata",
|
"AttentionMetadata",
|
||||||
"AttentionMetadataPerStage",
|
"Attention",
|
||||||
"get_attn_backend",
|
"get_attn_backend",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -21,7 +21,7 @@ class AttentionBackend(ABC):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage":
|
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -53,8 +53,34 @@ class AttentionBackend(ABC):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AttentionMetadataPerStage:
|
class AttentionMetadata:
|
||||||
"""Attention metadata for a specific stage. I.e., prefill or decode."""
|
"""Attention metadata for prefill and decode batched together."""
|
||||||
|
# Total number of prefill requests.
|
||||||
|
num_prefills: int
|
||||||
|
# Number of prefill tokens.
|
||||||
|
num_prefill_tokens: int
|
||||||
|
# Number of decode tokens. Note that it is equivalent to the number of
|
||||||
|
# decode requests.
|
||||||
|
num_decode_tokens: int
|
||||||
|
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||||
|
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
||||||
|
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
||||||
|
# in block 0, and 1st slot in block 1, respectively.
|
||||||
|
slot_mapping: torch.Tensor
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
|
||||||
|
"""Return the attention metadata that's required to run prefill
|
||||||
|
attention."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def decode_metadata(self) -> Optional["AttentionMetadata"]:
|
||||||
|
"""Return the attention metadata that's required to run decode
|
||||||
|
attention."""
|
||||||
|
pass
|
||||||
|
|
||||||
def asdict_zerocopy(self,
|
def asdict_zerocopy(self,
|
||||||
skip_fields: Optional[Set[str]] = None
|
skip_fields: Optional[Set[str]] = None
|
||||||
@ -70,40 +96,10 @@ class AttentionMetadataPerStage:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=AttentionMetadataPerStage)
|
T = TypeVar("T", bound=AttentionMetadata)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class AttentionImpl(ABC, Generic[T]):
|
||||||
class AttentionMetadata(Generic[T]):
|
|
||||||
"""Attention metadata for prefill and decode batched together."""
|
|
||||||
# Total number of prefill requests.
|
|
||||||
num_prefills: int
|
|
||||||
# Number of prefill tokens.
|
|
||||||
num_prefill_tokens: int
|
|
||||||
# Number of decode tokens. Note that it is equivalent to the number of
|
|
||||||
# decode requests.
|
|
||||||
num_decode_tokens: int
|
|
||||||
# The attention metadata for prefill requests in a batch.
|
|
||||||
# None if there's no prefill requests in a batch.
|
|
||||||
prefill_metadata: Optional[T]
|
|
||||||
# The attention metadata for decode requests in a batch.
|
|
||||||
# None if there's no decode requests in a batch.
|
|
||||||
decode_metadata: Optional[T]
|
|
||||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
|
||||||
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
|
||||||
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
|
||||||
# in block 0, and 1st slot in block 1, respectively.
|
|
||||||
slot_mapping: torch.Tensor
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.num_prefill_tokens > 0:
|
|
||||||
assert self.num_prefills > 0
|
|
||||||
assert self.prefill_metadata is not None
|
|
||||||
if self.num_decode_tokens > 0:
|
|
||||||
assert self.decode_metadata is not None
|
|
||||||
|
|
||||||
|
|
||||||
class AttentionImpl(ABC):
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -125,7 +121,7 @@ class AttentionImpl(ABC):
|
|||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: AttentionMetadata,
|
attn_metadata: T,
|
||||||
kv_scale: float = 1.0,
|
kv_scale: float = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@ -11,8 +11,7 @@ import torch
|
|||||||
from vllm_flash_attn import flash_attn_varlen_func
|
from vllm_flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionMetadata,
|
AttentionMetadata)
|
||||||
AttentionMetadataPerStage)
|
|
||||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||||
PagedAttentionMetadata)
|
PagedAttentionMetadata)
|
||||||
|
|
||||||
@ -58,8 +57,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlashAttentionMetadata(AttentionMetadataPerStage,
|
class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||||
PagedAttentionMetadata):
|
|
||||||
"""Metadata for FlashAttentionBackend.
|
"""Metadata for FlashAttentionBackend.
|
||||||
|
|
||||||
NOTE: Any python object stored here is not updated when it is
|
NOTE: Any python object stored here is not updated when it is
|
||||||
@ -67,9 +65,6 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
|
|||||||
dynamically, it should be stored in tensor. The tensor has to be
|
dynamically, it should be stored in tensor. The tensor has to be
|
||||||
updated from `CUDAGraphRunner.forward` API.
|
updated from `CUDAGraphRunner.forward` API.
|
||||||
"""
|
"""
|
||||||
# Currently, input sequences can only contain all prompts
|
|
||||||
# or all decoding. True if all sequences are prompts.
|
|
||||||
is_prompt: bool
|
|
||||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||||
# the computed tokens + new tokens None if it is a decoding.
|
# the computed tokens + new tokens None if it is a decoding.
|
||||||
seq_lens: Optional[List[int]]
|
seq_lens: Optional[List[int]]
|
||||||
@ -84,14 +79,18 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
|
|||||||
# |-------------------- seq_len ----------------------|
|
# |-------------------- seq_len ----------------------|
|
||||||
# |-- query_len ---|
|
# |-- query_len ---|
|
||||||
|
|
||||||
# Maximum query length in the batch.
|
# Maximum query length in the batch. None for decoding.
|
||||||
max_query_len: Optional[int]
|
max_query_len: Optional[int]
|
||||||
# Maximum sequence length in the batch.
|
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||||
max_seq_len: Optional[int]
|
# requests only.
|
||||||
|
max_prefill_seq_len: int
|
||||||
|
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||||
|
# requests only.
|
||||||
|
max_decode_seq_len: int
|
||||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||||
# the batch, used to index into subquery. E.g., if the subquery length
|
# the batch, used to index into subquery. E.g., if the subquery length
|
||||||
# is [4, 6], it is [0, 4, 10].
|
# is [4, 6], it is [0, 4, 10].
|
||||||
subquery_start_loc: Optional[torch.Tensor]
|
query_start_loc: Optional[torch.Tensor]
|
||||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||||
# [4, 6], it is [0, 4, 10].
|
# [4, 6], it is [0, 4, 10].
|
||||||
@ -105,6 +104,70 @@ class FlashAttentionMetadata(AttentionMetadataPerStage,
|
|||||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||||
use_cuda_graph: bool
|
use_cuda_graph: bool
|
||||||
|
|
||||||
|
_cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None
|
||||||
|
_cached_decode_metadata: Optional["FlashAttentionMetadata"] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
|
||||||
|
if self.num_prefills == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self._cached_prefill_metadata is not None:
|
||||||
|
return self._cached_prefill_metadata
|
||||||
|
|
||||||
|
assert self.seq_lens is not None
|
||||||
|
assert self.seq_lens_tensor is not None
|
||||||
|
assert self.query_start_loc is not None
|
||||||
|
assert self.context_lens_tensor is not None
|
||||||
|
assert self.block_tables is not None
|
||||||
|
assert self.seq_start_loc is not None
|
||||||
|
|
||||||
|
self._cached_prefill_metadata = FlashAttentionMetadata(
|
||||||
|
num_prefills=self.num_prefills,
|
||||||
|
num_prefill_tokens=self.num_prefill_tokens,
|
||||||
|
num_decode_tokens=0,
|
||||||
|
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
||||||
|
seq_lens=self.seq_lens[:self.num_prefills],
|
||||||
|
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||||
|
max_query_len=self.max_query_len,
|
||||||
|
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||||
|
max_decode_seq_len=0,
|
||||||
|
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
||||||
|
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
|
||||||
|
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
||||||
|
block_tables=self.block_tables[:self.num_prefills],
|
||||||
|
use_cuda_graph=False,
|
||||||
|
)
|
||||||
|
return self._cached_prefill_metadata
|
||||||
|
|
||||||
|
@property
|
||||||
|
def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
|
||||||
|
if self.num_decode_tokens == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self._cached_decode_metadata is not None:
|
||||||
|
return self._cached_decode_metadata
|
||||||
|
assert self.block_tables is not None
|
||||||
|
assert self.seq_lens_tensor is not None
|
||||||
|
|
||||||
|
self._cached_decode_metadata = FlashAttentionMetadata(
|
||||||
|
num_prefills=0,
|
||||||
|
num_prefill_tokens=0,
|
||||||
|
num_decode_tokens=self.num_decode_tokens,
|
||||||
|
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
||||||
|
seq_lens=None,
|
||||||
|
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||||
|
max_query_len=None,
|
||||||
|
max_prefill_seq_len=0,
|
||||||
|
max_decode_seq_len=self.max_decode_seq_len,
|
||||||
|
query_start_loc=None,
|
||||||
|
seq_start_loc=None,
|
||||||
|
context_lens_tensor=None,
|
||||||
|
block_tables=self.block_tables[self.num_prefills:],
|
||||||
|
use_cuda_graph=self.use_cuda_graph,
|
||||||
|
)
|
||||||
|
return self._cached_decode_metadata
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionImpl(AttentionImpl):
|
class FlashAttentionImpl(AttentionImpl):
|
||||||
"""
|
"""
|
||||||
@ -168,7 +231,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: AttentionMetadata[FlashAttentionMetadata],
|
attn_metadata: FlashAttentionMetadata,
|
||||||
kv_scale: float = 1.0,
|
kv_scale: float = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention and PagedAttention.
|
"""Forward pass with FlashAttention and PagedAttention.
|
||||||
@ -228,8 +291,8 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
v=value,
|
v=value,
|
||||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||||
max_seqlen_q=prefill_meta.max_seq_len,
|
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||||
max_seqlen_k=prefill_meta.max_seq_len,
|
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||||
softmax_scale=self.scale,
|
softmax_scale=self.scale,
|
||||||
causal=True,
|
causal=True,
|
||||||
window_size=self.sliding_window,
|
window_size=self.sliding_window,
|
||||||
@ -249,7 +312,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
prefill_meta.block_tables,
|
prefill_meta.block_tables,
|
||||||
prefill_meta.subquery_start_loc,
|
prefill_meta.query_start_loc,
|
||||||
prefill_meta.seq_lens_tensor,
|
prefill_meta.seq_lens_tensor,
|
||||||
prefill_meta.context_lens_tensor,
|
prefill_meta.context_lens_tensor,
|
||||||
prefill_meta.max_query_len,
|
prefill_meta.max_query_len,
|
||||||
@ -264,7 +327,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
value_cache,
|
value_cache,
|
||||||
decode_meta.block_tables,
|
decode_meta.block_tables,
|
||||||
decode_meta.seq_lens_tensor,
|
decode_meta.seq_lens_tensor,
|
||||||
decode_meta.max_seq_len,
|
decode_meta.max_decode_seq_len,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
|
|||||||
@ -8,8 +8,7 @@ from vllm_flash_attn import flash_attn_varlen_func
|
|||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionMetadata,
|
AttentionMetadata)
|
||||||
AttentionMetadataPerStage)
|
|
||||||
|
|
||||||
|
|
||||||
class FlashInferBackend(AttentionBackend):
|
class FlashInferBackend(AttentionBackend):
|
||||||
@ -56,9 +55,10 @@ class FlashInferBackend(AttentionBackend):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FlashInferMetadata(AttentionMetadataPerStage):
|
class FlashInferMetadata(AttentionMetadata):
|
||||||
|
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||||
is_prompt: bool
|
# requests only.
|
||||||
|
max_prefill_seq_len: int
|
||||||
|
|
||||||
use_cuda_graph: bool = False
|
use_cuda_graph: bool = False
|
||||||
|
|
||||||
@ -67,7 +67,6 @@ class FlashInferMetadata(AttentionMetadataPerStage):
|
|||||||
# Metadata for the prefill stage since we still
|
# Metadata for the prefill stage since we still
|
||||||
# use flash attention for prefill.
|
# use flash attention for prefill.
|
||||||
seq_start_loc: Optional[torch.Tensor] = None
|
seq_start_loc: Optional[torch.Tensor] = None
|
||||||
max_seq_len: Optional[int] = None
|
|
||||||
block_tables: Optional[torch.Tensor] = None
|
block_tables: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
# Metadata for the decode stage
|
# Metadata for the decode stage
|
||||||
@ -113,7 +112,8 @@ class FlashInferMetadata(AttentionMetadataPerStage):
|
|||||||
# When using flashinfer, we are also creating the FlashInferMetadata,
|
# When using flashinfer, we are also creating the FlashInferMetadata,
|
||||||
# which will also call post_init by default, here we want to skip the
|
# which will also call post_init by default, here we want to skip the
|
||||||
# post_init if it's the prefill phase.
|
# post_init if it's the prefill phase.
|
||||||
if not self.is_prompt:
|
if self.num_prefills == 0:
|
||||||
|
assert self.num_decode_tokens > 0
|
||||||
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||||
self.workspace_buffer, "NHD")
|
self.workspace_buffer, "NHD")
|
||||||
self.decode_wrapper.begin_forward(
|
self.decode_wrapper.begin_forward(
|
||||||
@ -138,6 +138,24 @@ class FlashInferMetadata(AttentionMetadataPerStage):
|
|||||||
skip_fields.add('decode_wrapper')
|
skip_fields.add('decode_wrapper')
|
||||||
return super().asdict_zerocopy(skip_fields)
|
return super().asdict_zerocopy(skip_fields)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
|
||||||
|
# Currently chunked prefill is not supported
|
||||||
|
if self.num_decode_tokens == 0:
|
||||||
|
assert self.num_prefills > 0
|
||||||
|
return self
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def decode_metadata(self) -> Optional["FlashInferMetadata"]:
|
||||||
|
# Currently chunked prefill is not supported
|
||||||
|
if self.num_prefills > 0:
|
||||||
|
assert self.num_decode_tokens == 0
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class FlashInferImpl(AttentionImpl):
|
class FlashInferImpl(AttentionImpl):
|
||||||
|
|
||||||
@ -172,7 +190,7 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: Optional[torch.Tensor],
|
kv_cache: Optional[torch.Tensor],
|
||||||
attn_metadata: AttentionMetadata[FlashInferMetadata],
|
attn_metadata: FlashInferMetadata,
|
||||||
kv_scale: float = 1.0,
|
kv_scale: float = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert kv_scale == 1.0
|
assert kv_scale == 1.0
|
||||||
@ -208,8 +226,8 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
v=value,
|
v=value,
|
||||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||||
max_seqlen_q=prefill_meta.max_seq_len,
|
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||||
max_seqlen_k=prefill_meta.max_seq_len,
|
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||||
softmax_scale=self.scale,
|
softmax_scale=self.scale,
|
||||||
causal=True,
|
causal=True,
|
||||||
window_size=self.sliding_window,
|
window_size=self.sliding_window,
|
||||||
|
|||||||
@ -6,8 +6,7 @@ import torch
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionMetadata,
|
AttentionMetadata)
|
||||||
AttentionMetadataPerStage)
|
|
||||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||||
PagedAttentionMetadata)
|
PagedAttentionMetadata)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -56,8 +55,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
|
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||||
PagedAttentionMetadata):
|
|
||||||
"""Metadata for FlashAttentionBackend.
|
"""Metadata for FlashAttentionBackend.
|
||||||
|
|
||||||
NOTE: Any python object stored here is not updated when it is
|
NOTE: Any python object stored here is not updated when it is
|
||||||
@ -65,9 +63,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
|
|||||||
dynamically, it should be stored in tensor. The tensor has to be
|
dynamically, it should be stored in tensor. The tensor has to be
|
||||||
updated from `CUDAGraphRunner.forward` API.
|
updated from `CUDAGraphRunner.forward` API.
|
||||||
"""
|
"""
|
||||||
# Currently, input sequences can only contain all prompts
|
|
||||||
# or all decoding. True if all sequences are prompts.
|
|
||||||
is_prompt: bool
|
|
||||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||||
# the computed tokens + new tokens None if it is a decoding.
|
# the computed tokens + new tokens None if it is a decoding.
|
||||||
seq_lens: Optional[List[int]]
|
seq_lens: Optional[List[int]]
|
||||||
@ -82,14 +77,18 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
|
|||||||
# |-------------------- seq_len ----------------------|
|
# |-------------------- seq_len ----------------------|
|
||||||
# |-- query_len ---|
|
# |-- query_len ---|
|
||||||
|
|
||||||
# Maximum query length in the batch.
|
# Maximum query length in the batch. None for decoding.
|
||||||
max_query_len: Optional[int]
|
max_query_len: Optional[int]
|
||||||
# Maximum sequence length in the batch.
|
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||||
max_seq_len: Optional[int]
|
# requests only.
|
||||||
|
max_prefill_seq_len: int
|
||||||
|
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||||
|
# requests only.
|
||||||
|
max_decode_seq_len: int
|
||||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||||
# the batch, used to index into subquery. E.g., if the subquery length
|
# the batch, used to index into subquery. E.g., if the subquery length
|
||||||
# is [4, 6], it is [0, 4, 10].
|
# is [4, 6], it is [0, 4, 10].
|
||||||
subquery_start_loc: Optional[torch.Tensor]
|
query_start_loc: Optional[torch.Tensor]
|
||||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||||
# [4, 6], it is [0, 4, 10].
|
# [4, 6], it is [0, 4, 10].
|
||||||
@ -102,6 +101,69 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
|
|||||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||||
# so far).
|
# so far).
|
||||||
context_lens_tensor: Optional[torch.Tensor]
|
context_lens_tensor: Optional[torch.Tensor]
|
||||||
|
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
||||||
|
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
|
||||||
|
if self.num_prefills == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self._cached_prefill_metadata is not None:
|
||||||
|
return self._cached_prefill_metadata
|
||||||
|
|
||||||
|
assert self.seq_lens is not None
|
||||||
|
assert self.seq_lens_tensor is not None
|
||||||
|
assert self.query_start_loc is not None
|
||||||
|
assert self.context_lens_tensor is not None
|
||||||
|
assert self.block_tables is not None
|
||||||
|
assert self.seq_start_loc is not None
|
||||||
|
|
||||||
|
self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
|
||||||
|
num_prefills=self.num_prefills,
|
||||||
|
num_prefill_tokens=self.num_prefill_tokens,
|
||||||
|
num_decode_tokens=0,
|
||||||
|
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
||||||
|
seq_lens=self.seq_lens[:self.num_prefills],
|
||||||
|
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||||
|
max_query_len=self.max_query_len,
|
||||||
|
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||||
|
max_decode_seq_len=0,
|
||||||
|
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
||||||
|
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
|
||||||
|
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
||||||
|
block_tables=self.block_tables[:self.num_prefills],
|
||||||
|
use_cuda_graph=False,
|
||||||
|
)
|
||||||
|
return self._cached_prefill_metadata
|
||||||
|
|
||||||
|
@property
|
||||||
|
def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
|
||||||
|
if self.num_decode_tokens == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self._cached_decode_metadata is not None:
|
||||||
|
return self._cached_decode_metadata
|
||||||
|
assert self.block_tables is not None
|
||||||
|
assert self.seq_lens_tensor is not None
|
||||||
|
|
||||||
|
self._cached_decode_metadata = ROCmFlashAttentionMetadata(
|
||||||
|
num_prefills=0,
|
||||||
|
num_prefill_tokens=0,
|
||||||
|
num_decode_tokens=self.num_decode_tokens,
|
||||||
|
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
||||||
|
seq_lens=None,
|
||||||
|
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||||
|
max_query_len=None,
|
||||||
|
max_prefill_seq_len=0,
|
||||||
|
max_decode_seq_len=self.max_decode_seq_len,
|
||||||
|
query_start_loc=None,
|
||||||
|
seq_start_loc=None,
|
||||||
|
context_lens_tensor=None,
|
||||||
|
block_tables=self.block_tables[self.num_prefills:],
|
||||||
|
use_cuda_graph=self.use_cuda_graph,
|
||||||
|
)
|
||||||
|
return self._cached_decode_metadata
|
||||||
|
|
||||||
|
|
||||||
class ROCmFlashAttentionImpl(AttentionImpl):
|
class ROCmFlashAttentionImpl(AttentionImpl):
|
||||||
@ -198,7 +260,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
|
attn_metadata: ROCmFlashAttentionMetadata,
|
||||||
kv_scale: float = 1.0,
|
kv_scale: float = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention and PagedAttention.
|
"""Forward pass with FlashAttention and PagedAttention.
|
||||||
@ -266,8 +328,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
None,
|
None,
|
||||||
prefill_meta.seq_start_loc,
|
prefill_meta.seq_start_loc,
|
||||||
prefill_meta.seq_start_loc,
|
prefill_meta.seq_start_loc,
|
||||||
prefill_meta.max_seq_len,
|
prefill_meta.max_prefill_seq_len,
|
||||||
prefill_meta.max_seq_len,
|
prefill_meta.max_prefill_seq_len,
|
||||||
True,
|
True,
|
||||||
self.scale,
|
self.scale,
|
||||||
)
|
)
|
||||||
@ -290,8 +352,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
v=value,
|
v=value,
|
||||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||||
max_seqlen_q=prefill_meta.max_seq_len,
|
max_seqlen_q=prefill_meta.max_prefill_seq_len,
|
||||||
max_seqlen_k=prefill_meta.max_seq_len,
|
max_seqlen_k=prefill_meta.max_prefill_seq_len,
|
||||||
softmax_scale=self.scale,
|
softmax_scale=self.scale,
|
||||||
causal=True,
|
causal=True,
|
||||||
)
|
)
|
||||||
@ -308,7 +370,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
prefill_meta.block_tables,
|
prefill_meta.block_tables,
|
||||||
prefill_meta.subquery_start_loc,
|
prefill_meta.query_start_loc,
|
||||||
prefill_meta.seq_lens_tensor,
|
prefill_meta.seq_lens_tensor,
|
||||||
prefill_meta.context_lens_tensor,
|
prefill_meta.context_lens_tensor,
|
||||||
prefill_meta.max_query_len,
|
prefill_meta.max_query_len,
|
||||||
@ -324,7 +386,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
value_cache,
|
value_cache,
|
||||||
decode_meta.block_tables,
|
decode_meta.block_tables,
|
||||||
decode_meta.seq_lens_tensor,
|
decode_meta.seq_lens_tensor,
|
||||||
decode_meta.max_seq_len,
|
decode_meta.max_decode_seq_len,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
|
|||||||
@ -7,8 +7,7 @@ import torch
|
|||||||
from torch.nn.functional import scaled_dot_product_attention
|
from torch.nn.functional import scaled_dot_product_attention
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionMetadata,
|
AttentionMetadata)
|
||||||
AttentionMetadataPerStage)
|
|
||||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||||
PagedAttentionMetadata)
|
PagedAttentionMetadata)
|
||||||
|
|
||||||
@ -54,8 +53,7 @@ class TorchSDPABackend(AttentionBackend):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
|
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||||
AttentionMetadataPerStage):
|
|
||||||
"""Metadata for TorchSDPABackend.
|
"""Metadata for TorchSDPABackend.
|
||||||
"""
|
"""
|
||||||
# Currently, input sequences can only contain all prompts
|
# Currently, input sequences can only contain all prompts
|
||||||
@ -72,8 +70,26 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
|
|||||||
# will not appear in the __repr__ and __init__
|
# will not appear in the __repr__ and __init__
|
||||||
self.attn_bias: Optional[List[torch.Tensor]] = None
|
self.attn_bias: Optional[List[torch.Tensor]] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]:
|
||||||
|
# Currently chunked prefill is not supported
|
||||||
|
if self.num_decode_tokens == 0:
|
||||||
|
assert self.num_prefills > 0
|
||||||
|
return self
|
||||||
|
|
||||||
class TorchSDPABackendImpl(AttentionImpl):
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def decode_metadata(self) -> Optional["TorchSDPAMetadata"]:
|
||||||
|
# Currently chunked prefill is not supported
|
||||||
|
if self.num_prefills > 0:
|
||||||
|
assert self.num_decode_tokens == 0
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -200,7 +216,7 @@ class TorchSDPABackendImpl(AttentionImpl):
|
|||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.block_tables,
|
attn_metadata.block_tables,
|
||||||
attn_metadata.seq_lens_tensor,
|
attn_metadata.seq_lens_tensor,
|
||||||
attn_metadata.max_seq_len,
|
attn_metadata.max_decode_seq_len,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
|
|||||||
@ -9,8 +9,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
|
|||||||
LowerTriangularMaskWithTensorBias)
|
LowerTriangularMaskWithTensorBias)
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionMetadata,
|
AttentionMetadata)
|
||||||
AttentionMetadataPerStage)
|
|
||||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||||
PagedAttentionMetadata)
|
PagedAttentionMetadata)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -59,7 +58,7 @@ class XFormersBackend(AttentionBackend):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
|
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||||
"""Metadata for XFormersbackend.
|
"""Metadata for XFormersbackend.
|
||||||
|
|
||||||
NOTE: Any python object stored here is not updated when it is
|
NOTE: Any python object stored here is not updated when it is
|
||||||
@ -67,9 +66,6 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
|
|||||||
dynamically, it should be stored in tensor. The tensor has to be
|
dynamically, it should be stored in tensor. The tensor has to be
|
||||||
updated from `CUDAGraphRunner.forward` API.
|
updated from `CUDAGraphRunner.forward` API.
|
||||||
"""
|
"""
|
||||||
# Currently, input sequences can only contain all prompts
|
|
||||||
# or all decoding. True if all sequences are prompts.
|
|
||||||
is_prompt: bool
|
|
||||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||||
# the computed tokens + new tokens None if it is a decoding.
|
# the computed tokens + new tokens None if it is a decoding.
|
||||||
seq_lens: Optional[List[int]]
|
seq_lens: Optional[List[int]]
|
||||||
@ -83,15 +79,19 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
|
|||||||
# |-------------------- seq_len ----------------------|
|
# |-------------------- seq_len ----------------------|
|
||||||
# |-- query_len ---|
|
# |-- query_len ---|
|
||||||
|
|
||||||
# Maximum query length in the batch.
|
# Maximum query length in the batch. None for decoding.
|
||||||
max_query_len: Optional[int]
|
max_query_len: Optional[int]
|
||||||
# FIXME: It is for flash attn.
|
# FIXME: It is for flash attn.
|
||||||
# Maximum sequence length in the batch.
|
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||||
max_seq_len: Optional[int]
|
# requests only.
|
||||||
|
max_prefill_seq_len: int
|
||||||
|
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||||
|
# requests only.
|
||||||
|
max_decode_seq_len: int
|
||||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||||
# the batch, used to index into subquery. E.g., if the subquery length
|
# the batch, used to index into subquery. E.g., if the subquery length
|
||||||
# is [4, 6], it is [0, 4, 10].
|
# is [4, 6], it is [0, 4, 10].
|
||||||
subquery_start_loc: Optional[torch.Tensor]
|
query_start_loc: Optional[torch.Tensor]
|
||||||
# FIXME: It is for flash attn.
|
# FIXME: It is for flash attn.
|
||||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||||
@ -105,6 +105,8 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
|
|||||||
# Cuda-graph is currently enabled for decoding only.
|
# Cuda-graph is currently enabled for decoding only.
|
||||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||||
use_cuda_graph: bool
|
use_cuda_graph: bool
|
||||||
|
_cached_prefill_metadata: Optional["XFormersMetadata"] = None
|
||||||
|
_cached_decode_metadata: Optional["XFormersMetadata"] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Set during the execution of the first attention op.
|
# Set during the execution of the first attention op.
|
||||||
@ -114,8 +116,68 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
|
|||||||
# will not appear in the __repr__ and __init__
|
# will not appear in the __repr__ and __init__
|
||||||
self.attn_bias: Optional[List[AttentionBias]] = None
|
self.attn_bias: Optional[List[AttentionBias]] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prefill_metadata(self) -> Optional["XFormersMetadata"]:
|
||||||
|
if self.num_prefills == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
class XFormersImpl(AttentionImpl):
|
if self._cached_prefill_metadata is not None:
|
||||||
|
return self._cached_prefill_metadata
|
||||||
|
|
||||||
|
assert self.seq_lens is not None
|
||||||
|
assert self.seq_lens_tensor is not None
|
||||||
|
assert self.query_start_loc is not None
|
||||||
|
assert self.context_lens_tensor is not None
|
||||||
|
assert self.block_tables is not None
|
||||||
|
|
||||||
|
self._cached_prefill_metadata = XFormersMetadata(
|
||||||
|
num_prefills=self.num_prefills,
|
||||||
|
num_prefill_tokens=self.num_prefill_tokens,
|
||||||
|
num_decode_tokens=0,
|
||||||
|
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
||||||
|
seq_lens=self.seq_lens[:self.num_prefills],
|
||||||
|
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||||
|
max_query_len=self.max_query_len,
|
||||||
|
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||||
|
max_decode_seq_len=0,
|
||||||
|
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
|
||||||
|
seq_start_loc=None,
|
||||||
|
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
|
||||||
|
block_tables=self.block_tables[:self.num_prefills],
|
||||||
|
use_cuda_graph=False,
|
||||||
|
)
|
||||||
|
return self._cached_prefill_metadata
|
||||||
|
|
||||||
|
@property
|
||||||
|
def decode_metadata(self) -> Optional["XFormersMetadata"]:
|
||||||
|
if self.num_decode_tokens == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self._cached_decode_metadata is not None:
|
||||||
|
return self._cached_decode_metadata
|
||||||
|
assert self.block_tables is not None
|
||||||
|
assert self.seq_lens_tensor is not None
|
||||||
|
|
||||||
|
self._cached_decode_metadata = XFormersMetadata(
|
||||||
|
num_prefills=0,
|
||||||
|
num_prefill_tokens=0,
|
||||||
|
num_decode_tokens=self.num_decode_tokens,
|
||||||
|
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
||||||
|
seq_lens=None,
|
||||||
|
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||||
|
max_query_len=None,
|
||||||
|
max_prefill_seq_len=0,
|
||||||
|
max_decode_seq_len=self.max_decode_seq_len,
|
||||||
|
query_start_loc=None,
|
||||||
|
seq_start_loc=None,
|
||||||
|
context_lens_tensor=None,
|
||||||
|
block_tables=self.block_tables[self.num_prefills:],
|
||||||
|
use_cuda_graph=self.use_cuda_graph,
|
||||||
|
)
|
||||||
|
return self._cached_decode_metadata
|
||||||
|
|
||||||
|
|
||||||
|
class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
||||||
"""
|
"""
|
||||||
If the input tensors contain prompt tokens, the layout is as follows:
|
If the input tensors contain prompt tokens, the layout is as follows:
|
||||||
|<--------------- num_prefill_tokens ----------------->|
|
|<--------------- num_prefill_tokens ----------------->|
|
||||||
@ -176,7 +238,7 @@ class XFormersImpl(AttentionImpl):
|
|||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: Optional[torch.Tensor],
|
kv_cache: Optional[torch.Tensor],
|
||||||
attn_metadata: AttentionMetadata[XFormersMetadata],
|
attn_metadata: "XFormersMetadata",
|
||||||
kv_scale: float = 1.0,
|
kv_scale: float = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with xFormers and PagedAttention.
|
"""Forward pass with xFormers and PagedAttention.
|
||||||
@ -244,7 +306,7 @@ class XFormersImpl(AttentionImpl):
|
|||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
prefill_meta.block_tables,
|
prefill_meta.block_tables,
|
||||||
prefill_meta.subquery_start_loc,
|
prefill_meta.query_start_loc,
|
||||||
prefill_meta.seq_lens_tensor,
|
prefill_meta.seq_lens_tensor,
|
||||||
prefill_meta.context_lens_tensor,
|
prefill_meta.context_lens_tensor,
|
||||||
prefill_meta.max_query_len,
|
prefill_meta.max_query_len,
|
||||||
@ -261,7 +323,7 @@ class XFormersImpl(AttentionImpl):
|
|||||||
value_cache,
|
value_cache,
|
||||||
decode_meta.block_tables,
|
decode_meta.block_tables,
|
||||||
decode_meta.seq_lens_tensor,
|
decode_meta.seq_lens_tensor,
|
||||||
decode_meta.max_seq_len,
|
decode_meta.max_decode_seq_len,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
|
|||||||
@ -4,8 +4,7 @@ from typing import List, Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import (AttentionMetadata,
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
AttentionMetadataPerStage)
|
|
||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
|
|
||||||
@ -57,7 +56,7 @@ class Attention(nn.Module):
|
|||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: Optional[torch.Tensor],
|
kv_cache: Optional[torch.Tensor],
|
||||||
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
|
attn_metadata: AttentionMetadata,
|
||||||
kv_scale: float = 1.0,
|
kv_scale: float = 1.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
|
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
|
||||||
|
|||||||
@ -16,8 +16,8 @@ class PagedAttentionMetadata:
|
|||||||
# (batch_size,). The length of sequences (entire tokens seen so far) per
|
# (batch_size,). The length of sequences (entire tokens seen so far) per
|
||||||
# sequence.
|
# sequence.
|
||||||
seq_lens_tensor: Optional[torch.Tensor]
|
seq_lens_tensor: Optional[torch.Tensor]
|
||||||
# Maximum sequence length in the batch.
|
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
|
||||||
max_seq_len: Optional[int]
|
max_decode_seq_len: int
|
||||||
# (batch_size, max_blocks_per_seq).
|
# (batch_size, max_blocks_per_seq).
|
||||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||||
@ -166,7 +166,7 @@ class PagedAttention:
|
|||||||
key_cache: torch.Tensor,
|
key_cache: torch.Tensor,
|
||||||
value_cache: torch.Tensor,
|
value_cache: torch.Tensor,
|
||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
subquery_start_loc: torch.Tensor,
|
query_start_loc: torch.Tensor,
|
||||||
seq_lens_tensor: torch.Tensor,
|
seq_lens_tensor: torch.Tensor,
|
||||||
context_lens: torch.Tensor,
|
context_lens: torch.Tensor,
|
||||||
max_query_len: int,
|
max_query_len: int,
|
||||||
@ -182,8 +182,8 @@ class PagedAttention:
|
|||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
block_tables,
|
block_tables,
|
||||||
# subquery_start_loc is (batch_size + 1,)
|
# query_start_loc is (batch_size + 1,)
|
||||||
subquery_start_loc[:-1],
|
query_start_loc[:-1],
|
||||||
seq_lens_tensor,
|
seq_lens_tensor,
|
||||||
context_lens,
|
context_lens,
|
||||||
max_query_len,
|
max_query_len,
|
||||||
|
|||||||
@ -618,6 +618,11 @@ class EngineArgs:
|
|||||||
decoding_config = DecodingConfig(
|
decoding_config = DecodingConfig(
|
||||||
guided_decoding_backend=self.guided_decoding_backend)
|
guided_decoding_backend=self.guided_decoding_backend)
|
||||||
|
|
||||||
|
if (model_config.get_sliding_window() is not None
|
||||||
|
and scheduler_config.chunked_prefill_enabled):
|
||||||
|
raise ValueError(
|
||||||
|
"Chunked prefill is not supported with sliding window.")
|
||||||
|
|
||||||
return EngineConfig(model_config=model_config,
|
return EngineConfig(model_config=model_config,
|
||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
parallel_config=parallel_config,
|
parallel_config=parallel_config,
|
||||||
|
|||||||
@ -122,6 +122,7 @@ class RejectionSampler(nn.Module):
|
|||||||
draft_token_ids,
|
draft_token_ids,
|
||||||
bonus_token_ids,
|
bonus_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output_token_ids
|
return output_token_ids
|
||||||
|
|
||||||
def _batch_modified_rejection_sampling(
|
def _batch_modified_rejection_sampling(
|
||||||
|
|||||||
@ -654,8 +654,9 @@ class SequenceGroupMetadata:
|
|||||||
return self.lora_request.lora_int_id if self.lora_request else 0
|
return self.lora_request.lora_int_id if self.lora_request else 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def token_chunk_size(self) -> Optional[int]:
|
def token_chunk_size(self) -> int:
|
||||||
"""Return the number of tokens to be processed (chunk size)."""
|
"""Return the number of tokens to be processed (chunk size)."""
|
||||||
|
assert self._token_chunk_size is not None
|
||||||
return self._token_chunk_size
|
return self._token_chunk_size
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -293,21 +293,30 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
|||||||
prompt_token_ids = seq_data.get_prompt_token_ids()
|
prompt_token_ids = seq_data.get_prompt_token_ids()
|
||||||
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
|
new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
|
||||||
|
|
||||||
|
new_seq_data_dict = {
|
||||||
|
target_seq_id:
|
||||||
|
SequenceData(
|
||||||
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
output_token_ids=new_output_token_ids,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
# This is a hack. Technically, spec decoding should compute
|
||||||
|
# num_lookahead slots at one shot, but instead, it expands the batch
|
||||||
|
# and evaluate one by one right now. context_len is seq_len - 1 because
|
||||||
|
# the kv cache is filled by a previous batch in the batch expansion.
|
||||||
|
for data in new_seq_data_dict.values():
|
||||||
|
data.update_num_computed_tokens(data.get_len() - 1)
|
||||||
|
|
||||||
return SequenceGroupMetadata(
|
return SequenceGroupMetadata(
|
||||||
request_id=seq_group_metadata.request_id,
|
request_id=seq_group_metadata.request_id,
|
||||||
is_prompt=seq_group_metadata.is_prompt,
|
is_prompt=seq_group_metadata.is_prompt,
|
||||||
seq_data={
|
seq_data=new_seq_data_dict,
|
||||||
target_seq_id:
|
|
||||||
SequenceData(
|
|
||||||
prompt_token_ids=prompt_token_ids,
|
|
||||||
output_token_ids=new_output_token_ids,
|
|
||||||
),
|
|
||||||
},
|
|
||||||
sampling_params=seq_group_metadata.sampling_params,
|
sampling_params=seq_group_metadata.sampling_params,
|
||||||
block_tables={
|
block_tables={
|
||||||
target_seq_id: seq_group_metadata.block_tables[seq_id],
|
target_seq_id: seq_group_metadata.block_tables[seq_id],
|
||||||
},
|
},
|
||||||
lora_request=None,
|
lora_request=None,
|
||||||
|
token_chunk_size=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _split_scoring_output(
|
def _split_scoring_output(
|
||||||
|
|||||||
@ -114,6 +114,7 @@ class MultiStepWorker(Worker):
|
|||||||
token_logprob = seq_output.logprobs[token_id]
|
token_logprob = seq_output.logprobs[token_id]
|
||||||
|
|
||||||
seq.append_token_id(token_id, token_logprob.logprob)
|
seq.append_token_id(token_id, token_logprob.logprob)
|
||||||
|
seq.update_num_computed_tokens(1)
|
||||||
|
|
||||||
def _shallow_copy_inputs(
|
def _shallow_copy_inputs(
|
||||||
self, seq_group_metadata_list: List[SequenceGroupMetadata]
|
self, seq_group_metadata_list: List[SequenceGroupMetadata]
|
||||||
|
|||||||
@ -159,12 +159,10 @@ class CPUModelRunner:
|
|||||||
is_prompt=True,
|
is_prompt=True,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
seq_lens_tensor=None,
|
seq_lens_tensor=None,
|
||||||
max_seq_len=None,
|
max_decode_seq_len=None,
|
||||||
num_prefills=len(seq_lens),
|
num_prefills=len(seq_lens),
|
||||||
num_prefill_tokens=num_prompt_tokens,
|
num_prefill_tokens=num_prompt_tokens,
|
||||||
num_decode_tokens=0,
|
num_decode_tokens=0,
|
||||||
prefill_metadata=None,
|
|
||||||
decode_metadata=None,
|
|
||||||
block_tables=torch.tensor([]),
|
block_tables=torch.tensor([]),
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
)
|
)
|
||||||
@ -213,7 +211,7 @@ class CPUModelRunner:
|
|||||||
block_table = block_table[-sliding_window_blocks:]
|
block_table = block_table[-sliding_window_blocks:]
|
||||||
block_tables.append(block_table)
|
block_tables.append(block_table)
|
||||||
|
|
||||||
max_seq_len = max(seq_lens)
|
max_decode_seq_len = max(seq_lens)
|
||||||
|
|
||||||
input_tokens = torch.tensor(input_tokens,
|
input_tokens = torch.tensor(input_tokens,
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
@ -243,12 +241,10 @@ class CPUModelRunner:
|
|||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
seq_lens_tensor=seq_lens_tensor,
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
max_seq_len=max_seq_len,
|
max_decode_seq_len=max_decode_seq_len,
|
||||||
num_prefill_tokens=0,
|
num_prefill_tokens=0,
|
||||||
num_decode_tokens=len(input_tokens),
|
num_decode_tokens=len(input_tokens),
|
||||||
num_prefills=0,
|
num_prefills=0,
|
||||||
prefill_metadata=None,
|
|
||||||
decode_metadata=None,
|
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
)
|
)
|
||||||
return (
|
return (
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
|
|||||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
|
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
|
||||||
from vllm.worker.model_runner import BatchType, ModelRunner
|
from vllm.worker.model_runner import ModelRunner
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -88,85 +88,24 @@ class EmbeddingModelRunner(ModelRunner):
|
|||||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
|
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
|
||||||
Set[LoRARequest], LoRAMapping, torch.Tensor]:
|
Set[LoRARequest], LoRAMapping, torch.Tensor]:
|
||||||
if self.is_driver_worker:
|
if self.is_driver_worker:
|
||||||
prefill_reqs = []
|
|
||||||
decode_reqs = []
|
|
||||||
for seq_group_meta in seq_group_metadata_list:
|
|
||||||
if seq_group_meta.is_prompt:
|
|
||||||
prefill_reqs.append(seq_group_meta)
|
|
||||||
else:
|
|
||||||
decode_reqs.append(seq_group_meta)
|
|
||||||
|
|
||||||
# Prepare input tensors.
|
# Prepare input tensors.
|
||||||
(
|
(
|
||||||
input_tokens,
|
input_tokens,
|
||||||
input_positions,
|
input_positions,
|
||||||
prefill_attn_metadata,
|
attn_metadata,
|
||||||
prompt_lens,
|
seq_lens,
|
||||||
subquery_lens,
|
_,
|
||||||
lora_index_mapping,
|
lora_mapping,
|
||||||
lora_prompt_mapping,
|
|
||||||
lora_requests,
|
lora_requests,
|
||||||
multi_modal_input,
|
multi_modal_input,
|
||||||
slot_mapping,
|
slot_mapping,
|
||||||
) = self._prepare_prompt(prefill_reqs)
|
num_prefill_tokens,
|
||||||
(
|
num_decode_tokens,
|
||||||
decode_input_tokens,
|
num_prefills,
|
||||||
decode_input_positions,
|
) = self._prepare_model_input(seq_group_metadata_list)
|
||||||
decode_attn_metadata,
|
|
||||||
decode_lora_index_mapping,
|
|
||||||
decode_lora_prompt_mapping,
|
|
||||||
decode_lora_requests,
|
|
||||||
decode_slot_mapping,
|
|
||||||
) = self._prepare_decode(decode_reqs)
|
|
||||||
|
|
||||||
# Prepare PoolingMetadata
|
# Prepare PoolingMetadata
|
||||||
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
|
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
|
||||||
prompt_lens)
|
seq_lens)
|
||||||
|
|
||||||
if not self.scheduler_config.chunked_prefill_enabled:
|
|
||||||
assert (len(prefill_reqs) and len(decode_reqs)) == 0
|
|
||||||
|
|
||||||
num_prefills = len(prompt_lens)
|
|
||||||
num_prefill_tokens = len(input_tokens)
|
|
||||||
num_decode_tokens = len(decode_input_tokens)
|
|
||||||
|
|
||||||
# Coalesce tensors. Note that attn_metadata is currently not
|
|
||||||
# coalesced for simplicity.
|
|
||||||
input_tokens.extend(decode_input_tokens)
|
|
||||||
input_positions.extend(decode_input_positions)
|
|
||||||
slot_mapping.extend(decode_slot_mapping)
|
|
||||||
lora_index_mapping.extend(decode_lora_index_mapping)
|
|
||||||
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
|
|
||||||
lora_requests.update(decode_lora_requests)
|
|
||||||
|
|
||||||
input_tokens = torch.tensor(input_tokens,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=self.device)
|
|
||||||
input_positions = torch.tensor(input_positions,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=self.device)
|
|
||||||
slot_mapping = torch.tensor(slot_mapping,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=self.device)
|
|
||||||
|
|
||||||
if self.lora_config:
|
|
||||||
lora_mapping = LoRAMapping(
|
|
||||||
lora_index_mapping,
|
|
||||||
lora_prompt_mapping,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
lora_mapping = None
|
|
||||||
|
|
||||||
# Broadcast the metadata.
|
|
||||||
# If batch contains both prefill and decode, it sends 2 broadcasts.
|
|
||||||
# If it only contains 1 type, it triggers a single broadcast.
|
|
||||||
if (prefill_attn_metadata is not None
|
|
||||||
and decode_attn_metadata is not None):
|
|
||||||
batch_type = BatchType.MIXED
|
|
||||||
elif prefill_attn_metadata is not None:
|
|
||||||
batch_type = BatchType.PREFILL
|
|
||||||
else:
|
|
||||||
batch_type = BatchType.DECODE
|
|
||||||
|
|
||||||
metadata_dict = {
|
metadata_dict = {
|
||||||
"input_tokens": input_tokens,
|
"input_tokens": input_tokens,
|
||||||
@ -178,65 +117,26 @@ class EmbeddingModelRunner(ModelRunner):
|
|||||||
"num_decode_tokens": num_decode_tokens,
|
"num_decode_tokens": num_decode_tokens,
|
||||||
"slot_mapping": slot_mapping,
|
"slot_mapping": slot_mapping,
|
||||||
"num_prefills": num_prefills,
|
"num_prefills": num_prefills,
|
||||||
"batch_type": batch_type,
|
|
||||||
}
|
}
|
||||||
if prefill_attn_metadata is not None:
|
if attn_metadata:
|
||||||
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
|
metadata_dict.update(attn_metadata.asdict_zerocopy())
|
||||||
else:
|
|
||||||
assert decode_attn_metadata is not None
|
|
||||||
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
|
|
||||||
broadcast_tensor_dict(metadata_dict, src=0)
|
broadcast_tensor_dict(metadata_dict, src=0)
|
||||||
|
|
||||||
# Broadcast decode attn metadata for mixed batch type.
|
|
||||||
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
|
|
||||||
# We can potentially reduce the overhead by coalescing tensors.
|
|
||||||
if batch_type == BatchType.MIXED:
|
|
||||||
assert decode_attn_metadata is not None
|
|
||||||
metadata_dict = decode_attn_metadata.asdict_zerocopy()
|
|
||||||
broadcast_tensor_dict(metadata_dict, src=0)
|
|
||||||
else:
|
else:
|
||||||
metadata_dict = broadcast_tensor_dict(src=0)
|
metadata_dict = broadcast_tensor_dict(src=0)
|
||||||
input_tokens = metadata_dict.pop("input_tokens")
|
input_tokens = metadata_dict.pop("input_tokens")
|
||||||
input_positions = metadata_dict.pop("input_positions")
|
input_positions = metadata_dict.pop("input_positions")
|
||||||
slot_mapping = metadata_dict.pop("slot_mapping")
|
|
||||||
num_prefills = metadata_dict.pop("num_prefills")
|
|
||||||
lora_mapping = metadata_dict.pop("lora_mapping")
|
lora_mapping = metadata_dict.pop("lora_mapping")
|
||||||
lora_requests = metadata_dict.pop("lora_requests")
|
lora_requests = metadata_dict.pop("lora_requests")
|
||||||
multi_modal_input = metadata_dict.pop("multi_modal_input")
|
multi_modal_input = metadata_dict.pop("multi_modal_input")
|
||||||
num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
|
if metadata_dict:
|
||||||
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
|
attn_metadata = self.attn_backend.make_metadata(
|
||||||
batch_type = metadata_dict.pop("batch_type")
|
|
||||||
|
|
||||||
# Create an attention metadata.
|
|
||||||
prefill_attn_metadata = None
|
|
||||||
decode_attn_metadata = None
|
|
||||||
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
|
|
||||||
prefill_attn_metadata = self.attn_backend.make_metadata(
|
|
||||||
**metadata_dict)
|
**metadata_dict)
|
||||||
else:
|
else:
|
||||||
decode_attn_metadata = self.attn_backend.make_metadata(
|
attn_metadata = None
|
||||||
**metadata_dict)
|
|
||||||
|
|
||||||
pooling_metadata = PoolingMetadata(seq_groups=None,
|
pooling_metadata = PoolingMetadata(seq_groups=None,
|
||||||
seq_data=None,
|
seq_data=None,
|
||||||
prompt_lens=None)
|
prompt_lens=None)
|
||||||
|
|
||||||
# if it is a mixed batch, decode attn_metadata is broadcasted
|
|
||||||
# separately.
|
|
||||||
if batch_type == BatchType.MIXED:
|
|
||||||
metadata_dict = broadcast_tensor_dict(src=0)
|
|
||||||
decode_attn_metadata = self.attn_backend.make_metadata(
|
|
||||||
**metadata_dict)
|
|
||||||
|
|
||||||
attn_metadata = AttentionMetadata(
|
|
||||||
num_prefills=num_prefills,
|
|
||||||
slot_mapping=slot_mapping,
|
|
||||||
num_prefill_tokens=num_prefill_tokens,
|
|
||||||
num_decode_tokens=num_decode_tokens,
|
|
||||||
prefill_metadata=prefill_attn_metadata,
|
|
||||||
decode_metadata=decode_attn_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
return (input_tokens, input_positions, attn_metadata, pooling_metadata,
|
return (input_tokens, input_positions, attn_metadata, pooling_metadata,
|
||||||
lora_requests, lora_mapping, multi_modal_input)
|
lora_requests, lora_mapping, multi_modal_input)
|
||||||
|
|
||||||
|
|||||||
@ -1,13 +1,11 @@
|
|||||||
import time
|
import time
|
||||||
from enum import IntEnum
|
|
||||||
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
|
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
|
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||||
get_attn_backend)
|
|
||||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||||
VisionLanguageConfig)
|
VisionLanguageConfig)
|
||||||
@ -37,66 +35,38 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class PreparePromptMetadata(NamedTuple):
|
class ModelInput(NamedTuple):
|
||||||
input_tokens: List[int]
|
input_tokens: torch.Tensor
|
||||||
input_positions: List[int]
|
input_positions: torch.Tensor
|
||||||
attn_metadata: Optional[AttentionMetadataPerStage]
|
attn_metadata: Optional[AttentionMetadata]
|
||||||
seq_lens: List[int]
|
seq_lens: List[int]
|
||||||
query_lens: List[int]
|
query_lens: List[int]
|
||||||
lora_index_mapping: List[int]
|
lora_mapping: Optional[LoRAMapping]
|
||||||
lora_prompt_mapping: List[int]
|
|
||||||
lora_requests: Set[LoRARequest]
|
lora_requests: Set[LoRARequest]
|
||||||
multi_modal_input: Optional[torch.Tensor]
|
multi_modal_input: Optional[torch.Tensor]
|
||||||
slot_mapping: List[int]
|
slot_mapping: torch.Tensor
|
||||||
|
num_prefill_tokens: int
|
||||||
|
num_decode_tokens: int
|
||||||
|
num_prefills: int
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def empty(cls):
|
def empty(cls, device):
|
||||||
return PreparePromptMetadata(
|
return ModelInput(
|
||||||
input_tokens=[],
|
input_tokens=torch.empty(0, device=device),
|
||||||
input_positions=[],
|
input_positions=torch.empty(0, device=device),
|
||||||
attn_metadata=None,
|
attn_metadata=None,
|
||||||
seq_lens=[],
|
seq_lens=[],
|
||||||
query_lens=[],
|
query_lens=[],
|
||||||
lora_index_mapping=[],
|
lora_mapping=None,
|
||||||
lora_prompt_mapping=[],
|
|
||||||
lora_requests=set(),
|
lora_requests=set(),
|
||||||
multi_modal_input=None,
|
multi_modal_input=None,
|
||||||
slot_mapping=[],
|
slot_mapping=torch.empty(0, device=device),
|
||||||
|
num_prefill_tokens=0,
|
||||||
|
num_decode_tokens=0,
|
||||||
|
num_prefills=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PrepareDecodeMetadata(NamedTuple):
|
|
||||||
input_tokens: List[int]
|
|
||||||
input_positions: List[int]
|
|
||||||
attn_metadata: Optional[AttentionMetadata]
|
|
||||||
lora_index_mapping: List[int]
|
|
||||||
lora_prompt_mapping: List[int]
|
|
||||||
lora_requests: Set[LoRARequest]
|
|
||||||
slot_mapping: List[int]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def empty(cls):
|
|
||||||
return PrepareDecodeMetadata(
|
|
||||||
input_tokens=[],
|
|
||||||
input_positions=[],
|
|
||||||
attn_metadata=None,
|
|
||||||
lora_index_mapping=[],
|
|
||||||
lora_prompt_mapping=[],
|
|
||||||
lora_requests=set(),
|
|
||||||
slot_mapping=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# How batches are constructed.
|
|
||||||
class BatchType(IntEnum):
|
|
||||||
# Every batch is prefill.
|
|
||||||
PREFILL = 0
|
|
||||||
# Every batch is decode.
|
|
||||||
DECODE = 1
|
|
||||||
# Batch is a mixture of prefill and decode.
|
|
||||||
MIXED = 2
|
|
||||||
|
|
||||||
|
|
||||||
class ModelRunner:
|
class ModelRunner:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -216,10 +186,22 @@ class ModelRunner:
|
|||||||
block_size = self.block_size
|
block_size = self.block_size
|
||||||
return (self.max_seq_len_to_capture + block_size - 1) // block_size
|
return (self.max_seq_len_to_capture + block_size - 1) // block_size
|
||||||
|
|
||||||
def _prepare_prompt(
|
def _prepare_model_input(
|
||||||
self,
|
self,
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
) -> PreparePromptMetadata:
|
) -> ModelInput:
|
||||||
|
"""Prepare the model input based on a given sequence group.
|
||||||
|
|
||||||
|
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
|
||||||
|
|
||||||
|
The result tensors and data structures also batch inputs in prefill
|
||||||
|
-> decode order. For example,
|
||||||
|
|
||||||
|
- input_tokens[:num_prefill_tokens] contains prefill tokens.
|
||||||
|
- input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||||
|
|
||||||
|
If cuda graph is required, this API automatically pads inputs.
|
||||||
|
"""
|
||||||
input_tokens: List[int] = []
|
input_tokens: List[int] = []
|
||||||
input_positions: List[int] = []
|
input_positions: List[int] = []
|
||||||
slot_mapping: List[int] = []
|
slot_mapping: List[int] = []
|
||||||
@ -228,212 +210,16 @@ class ModelRunner:
|
|||||||
lora_requests: Set[LoRARequest] = set()
|
lora_requests: Set[LoRARequest] = set()
|
||||||
|
|
||||||
seq_lens: List[int] = []
|
seq_lens: List[int] = []
|
||||||
|
prefill_seq_lens: List[int] = []
|
||||||
|
decode_seq_lens: List[int] = []
|
||||||
context_lens: List[int] = []
|
context_lens: List[int] = []
|
||||||
query_lens: List[int] = []
|
query_lens: List[int] = []
|
||||||
prefix_block_tables: List[List[int]] = []
|
|
||||||
multi_modal_input_list: List[torch.Tensor] = []
|
|
||||||
|
|
||||||
if len(seq_group_metadata_list) == 0:
|
|
||||||
return PreparePromptMetadata.empty()
|
|
||||||
|
|
||||||
for seq_group_metadata in seq_group_metadata_list:
|
|
||||||
assert seq_group_metadata.is_prompt
|
|
||||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
|
||||||
assert len(seq_ids) == 1
|
|
||||||
seq_id = seq_ids[0]
|
|
||||||
|
|
||||||
computed_block_nums = seq_group_metadata.computed_block_nums
|
|
||||||
if (self.scheduler_config is not None
|
|
||||||
and self.scheduler_config.chunked_prefill_enabled
|
|
||||||
and not (computed_block_nums is None
|
|
||||||
or computed_block_nums == [])):
|
|
||||||
raise RuntimeError(
|
|
||||||
"chunked prefill cannot be used with prefix caching "
|
|
||||||
"now.")
|
|
||||||
|
|
||||||
token_chunk_size = seq_group_metadata.token_chunk_size
|
|
||||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
|
||||||
context_len = seq_data.get_num_computed_tokens()
|
|
||||||
# We should use get_len here because in case of preemption
|
|
||||||
# it contains output tokens.
|
|
||||||
seq_len = min(seq_data.get_len(), context_len + token_chunk_size)
|
|
||||||
prompt_tokens = seq_data.get_token_ids()[context_len:seq_len]
|
|
||||||
seq_lens.append(seq_len)
|
|
||||||
|
|
||||||
# NOTE: This only works for oooooooxxx style attention.
|
|
||||||
if computed_block_nums is not None and len(
|
|
||||||
computed_block_nums) > 0 and self.sliding_window is None:
|
|
||||||
# Prefix is not supported with sliding_window
|
|
||||||
context_len = len(computed_block_nums) * self.block_size
|
|
||||||
prompt_tokens = prompt_tokens[context_len:]
|
|
||||||
prefix_block_tables.append(computed_block_nums)
|
|
||||||
elif self.scheduler_config.chunked_prefill_enabled:
|
|
||||||
if seq_group_metadata.block_tables is not None:
|
|
||||||
# Prefill has chunked before.
|
|
||||||
block_table = seq_group_metadata.block_tables[seq_id]
|
|
||||||
prefix_block_tables.append(block_table)
|
|
||||||
else:
|
|
||||||
# The first prefill.
|
|
||||||
prefix_block_tables.append([])
|
|
||||||
else:
|
|
||||||
prefix_block_tables.append([])
|
|
||||||
# Right now, prefill start is always 0. However, this
|
|
||||||
# assumption can be changed once chunked prefill is introduced.
|
|
||||||
assert context_len == 0
|
|
||||||
|
|
||||||
# actual prompt lens
|
|
||||||
context_lens.append(context_len)
|
|
||||||
query_lens.append(seq_len - context_len)
|
|
||||||
|
|
||||||
input_tokens.extend(prompt_tokens)
|
|
||||||
# NOTE(woosuk): Here we assume that the first token in the prompt
|
|
||||||
# is always the first token in the sequence.
|
|
||||||
input_positions.extend(list(range(context_len, seq_len)))
|
|
||||||
lora_id = seq_group_metadata.lora_int_id
|
|
||||||
|
|
||||||
if lora_id > 0:
|
|
||||||
lora_requests.add(seq_group_metadata.lora_request)
|
|
||||||
|
|
||||||
lora_index_mapping += [lora_id] * (seq_len - context_len)
|
|
||||||
lora_prompt_mapping.extend([lora_id] * (
|
|
||||||
seq_len - context_len if seq_group_metadata.sampling_params
|
|
||||||
and seq_group_metadata.sampling_params.prompt_logprobs else 1))
|
|
||||||
|
|
||||||
if seq_group_metadata.multi_modal_data:
|
|
||||||
multi_modal_input_list.append(
|
|
||||||
seq_group_metadata.multi_modal_data.data)
|
|
||||||
|
|
||||||
if _is_block_tables_empty(seq_group_metadata.block_tables):
|
|
||||||
# During memory profiling, the block tables are not initialized
|
|
||||||
# yet. In this case, we just use a dummy slot mapping.
|
|
||||||
# In embeddings, the block tables are {seq_id: None}.
|
|
||||||
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Compute the slot mapping.
|
|
||||||
block_table = seq_group_metadata.block_tables[seq_id]
|
|
||||||
|
|
||||||
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
|
|
||||||
# where start_idx is max(0, seq_len - sliding_window).
|
|
||||||
# For example, if the prompt len is 10, sliding window is 8, and
|
|
||||||
# block size is 4, the first two tokens are masked and the slot
|
|
||||||
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
|
||||||
start_idx = 0
|
|
||||||
if self.sliding_window is not None:
|
|
||||||
assert context_len == 0, (
|
|
||||||
"Prefix caching is currently not supported with "
|
|
||||||
"sliding window attention")
|
|
||||||
start_idx = max(0, seq_len - self.sliding_window)
|
|
||||||
|
|
||||||
for i in range(context_len, seq_len):
|
|
||||||
if i < start_idx:
|
|
||||||
slot_mapping.append(_PAD_SLOT_ID)
|
|
||||||
continue
|
|
||||||
|
|
||||||
block_number = block_table[i // self.block_size]
|
|
||||||
block_offset = i % self.block_size
|
|
||||||
slot = block_number * self.block_size + block_offset
|
|
||||||
slot_mapping.append(slot)
|
|
||||||
|
|
||||||
max_query_len = max(query_lens)
|
|
||||||
max_seq_len = max(seq_lens)
|
|
||||||
assert max_query_len > 0
|
|
||||||
|
|
||||||
context_lens_tensor = torch.tensor(context_lens,
|
|
||||||
dtype=torch.int,
|
|
||||||
device=self.device)
|
|
||||||
|
|
||||||
if multi_modal_input_list:
|
|
||||||
assert self.vision_language_config, (
|
|
||||||
"Multi-modal inputs are only supported by "
|
|
||||||
"vision language models.")
|
|
||||||
multi_modal_input = torch.cat(multi_modal_input_list,
|
|
||||||
dim=0).to(self.device)
|
|
||||||
else:
|
|
||||||
multi_modal_input = None
|
|
||||||
|
|
||||||
# Prepare prefix block tables
|
|
||||||
max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
|
|
||||||
block_tables = make_tensor_with_pad(
|
|
||||||
prefix_block_tables,
|
|
||||||
max_len=max_prompt_block_table_len,
|
|
||||||
pad=0,
|
|
||||||
dtype=torch.int,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Query length can be shorter than key (i.e., prompt) when prefill
|
|
||||||
# is chunked or prefix cached.
|
|
||||||
query_lens_tensor = torch.tensor(query_lens,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=self.device)
|
|
||||||
subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=self.device)
|
|
||||||
|
|
||||||
seq_lens_tensor = torch.tensor(seq_lens,
|
|
||||||
dtype=torch.int,
|
|
||||||
device=self.device)
|
|
||||||
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=self.device)
|
|
||||||
|
|
||||||
torch.cumsum(query_lens_tensor,
|
|
||||||
dim=0,
|
|
||||||
dtype=subquery_start_loc.dtype,
|
|
||||||
out=subquery_start_loc[1:])
|
|
||||||
|
|
||||||
torch.cumsum(seq_lens_tensor,
|
|
||||||
dim=0,
|
|
||||||
dtype=seq_start_loc.dtype,
|
|
||||||
out=seq_start_loc[1:])
|
|
||||||
|
|
||||||
if self.attn_backend.get_name() == "flashinfer":
|
|
||||||
attn_metadata = self.attn_backend.make_metadata(
|
|
||||||
is_prompt=True,
|
|
||||||
use_cuda_graph=False,
|
|
||||||
seq_start_loc=seq_start_loc,
|
|
||||||
max_seq_len=max_seq_len,
|
|
||||||
block_tables=block_tables)
|
|
||||||
else:
|
|
||||||
attn_metadata = self.attn_backend.make_metadata(
|
|
||||||
is_prompt=True,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
seq_lens_tensor=seq_lens_tensor,
|
|
||||||
max_query_len=max_query_len,
|
|
||||||
max_seq_len=max_seq_len,
|
|
||||||
subquery_start_loc=subquery_start_loc,
|
|
||||||
seq_start_loc=seq_start_loc,
|
|
||||||
context_lens_tensor=context_lens_tensor,
|
|
||||||
block_tables=block_tables,
|
|
||||||
use_cuda_graph=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
return PreparePromptMetadata(
|
|
||||||
input_tokens=input_tokens,
|
|
||||||
input_positions=input_positions,
|
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
seq_lens=seq_lens,
|
|
||||||
query_lens=query_lens,
|
|
||||||
lora_index_mapping=lora_index_mapping,
|
|
||||||
lora_prompt_mapping=lora_prompt_mapping,
|
|
||||||
lora_requests=lora_requests,
|
|
||||||
multi_modal_input=multi_modal_input,
|
|
||||||
slot_mapping=slot_mapping,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _prepare_decode(
|
|
||||||
self,
|
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
|
||||||
) -> PrepareDecodeMetadata:
|
|
||||||
input_tokens: List[int] = []
|
|
||||||
input_positions: List[int] = []
|
|
||||||
slot_mapping: List[int] = []
|
|
||||||
seq_lens: List[int] = []
|
|
||||||
block_tables: List[List[int]] = []
|
block_tables: List[List[int]] = []
|
||||||
lora_index_mapping: List[int] = []
|
multi_modal_input_list: List[torch.Tensor] = []
|
||||||
lora_prompt_mapping: List[int] = []
|
decode_only = True
|
||||||
lora_requests: Set[LoRARequest] = set()
|
num_prefills = 0
|
||||||
|
num_prefill_tokens = 0
|
||||||
|
num_decode_tokens = 0
|
||||||
|
|
||||||
# The following fields are only for flashinfer
|
# The following fields are only for flashinfer
|
||||||
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
|
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
|
||||||
@ -454,60 +240,186 @@ class ModelRunner:
|
|||||||
paged_kv_last_page_len: List[int] = []
|
paged_kv_last_page_len: List[int] = []
|
||||||
|
|
||||||
if len(seq_group_metadata_list) == 0:
|
if len(seq_group_metadata_list) == 0:
|
||||||
return PrepareDecodeMetadata.empty()
|
return ModelInput.empty(self.device)
|
||||||
|
|
||||||
for seq_group_metadata in seq_group_metadata_list:
|
for seq_group_metadata in seq_group_metadata_list:
|
||||||
assert not seq_group_metadata.is_prompt
|
|
||||||
assert seq_group_metadata.token_chunk_size == 1
|
|
||||||
|
|
||||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||||
lora_id = seq_group_metadata.lora_int_id
|
is_prompt = seq_group_metadata.is_prompt
|
||||||
|
|
||||||
if lora_id > 0:
|
|
||||||
lora_requests.add(seq_group_metadata.lora_request)
|
|
||||||
|
|
||||||
for seq_id in seq_ids:
|
for seq_id in seq_ids:
|
||||||
|
computed_block_nums = seq_group_metadata.computed_block_nums
|
||||||
|
if (self.scheduler_config is not None
|
||||||
|
and self.scheduler_config.chunked_prefill_enabled
|
||||||
|
and not (computed_block_nums is None
|
||||||
|
or computed_block_nums == [])):
|
||||||
|
raise RuntimeError(
|
||||||
|
"chunked prefill cannot be used with prefix caching "
|
||||||
|
"now.")
|
||||||
|
|
||||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||||
generation_token = seq_data.get_last_token_id()
|
if is_prompt:
|
||||||
input_tokens.append(generation_token)
|
context_len = seq_data.get_num_computed_tokens()
|
||||||
|
else:
|
||||||
|
# get_num_computed_tokens is incorrect for spec decoding.
|
||||||
|
# So, we should have a special logic here.
|
||||||
|
# TODO(sang): Fix it.
|
||||||
|
context_len = seq_data.get_len() - 1
|
||||||
|
|
||||||
seq_len = seq_data.get_len()
|
seq_len = min(
|
||||||
position = seq_len - 1
|
seq_data.get_len(),
|
||||||
input_positions.append(position)
|
context_len + seq_group_metadata.token_chunk_size)
|
||||||
|
if is_prompt:
|
||||||
|
tokens = seq_data.get_token_ids()[context_len:seq_len]
|
||||||
|
else:
|
||||||
|
# Optimization. get_token_ids requires the entire copy of
|
||||||
|
# tokens.
|
||||||
|
tokens = [seq_data.get_last_token_id()]
|
||||||
|
|
||||||
seq_len = seq_len if self.sliding_window is None else min(
|
# Prefix cache was hit.
|
||||||
seq_len, self.sliding_window)
|
# Prefix is not supported with sliding_window
|
||||||
seq_lens.append(seq_len)
|
prefix_cache_hit = (computed_block_nums is not None
|
||||||
|
and len(computed_block_nums) > 0
|
||||||
|
and self.sliding_window is None
|
||||||
|
and is_prompt)
|
||||||
|
|
||||||
block_table = seq_group_metadata.block_tables[seq_id]
|
# TODO(sang): Combine chunked prefill and prefix caching by
|
||||||
block_number = block_table[position // self.block_size]
|
# only allowing multiple of block_size chunk size.
|
||||||
block_offset = position % self.block_size
|
# NOTE: This only works for oooooooxxx style attention.
|
||||||
slot = block_number * self.block_size + block_offset
|
if prefix_cache_hit:
|
||||||
slot_mapping.append(slot)
|
assert computed_block_nums is not None
|
||||||
lora_index_mapping.append(lora_id)
|
context_len = len(computed_block_nums) * self.block_size
|
||||||
lora_prompt_mapping.append(lora_id)
|
tokens = tokens[context_len:]
|
||||||
|
if self.attn_backend.get_name() == "flash-attn":
|
||||||
|
# NOTE(woosuk): For flash-attn, the block table should
|
||||||
|
# include the entries for the incoming prefill tokens.
|
||||||
|
# TODO(woosuk): This is a temporary fix. We should
|
||||||
|
# provide a unified interface for different backends.
|
||||||
|
block_table = seq_group_metadata.block_tables[seq_id]
|
||||||
|
else:
|
||||||
|
block_table = computed_block_nums
|
||||||
|
elif (self.scheduler_config.chunked_prefill_enabled
|
||||||
|
or not is_prompt):
|
||||||
|
if seq_group_metadata.block_tables is not None:
|
||||||
|
# chunked prefill or decode
|
||||||
|
block_table = seq_group_metadata.block_tables[seq_id]
|
||||||
|
if self.sliding_window is not None:
|
||||||
|
# chunked prefill doesn't support sliding window.
|
||||||
|
assert (not self.scheduler_config.
|
||||||
|
chunked_prefill_enabled)
|
||||||
|
sliding_window_blocks = (self.sliding_window //
|
||||||
|
self.block_size)
|
||||||
|
block_table = block_table[-sliding_window_blocks:]
|
||||||
|
|
||||||
if self.sliding_window is not None:
|
if self.attn_backend.get_name() == "flashinfer":
|
||||||
sliding_window_blocks = (self.sliding_window //
|
paged_kv_indices.extend(block_table)
|
||||||
self.block_size)
|
paged_kv_indptr.append(paged_kv_indptr[-1] +
|
||||||
block_table = block_table[-sliding_window_blocks:]
|
len(block_table))
|
||||||
|
last_page_len = seq_data.get_len(
|
||||||
|
) % self.block_size
|
||||||
|
if last_page_len == 0:
|
||||||
|
last_page_len = self.block_size
|
||||||
|
paged_kv_last_page_len.append(last_page_len)
|
||||||
|
else:
|
||||||
|
# Only happens when memory profiling runs.
|
||||||
|
block_table = []
|
||||||
|
else:
|
||||||
|
# Prefill without chunked prefill or memory profiling.
|
||||||
|
block_table = []
|
||||||
block_tables.append(block_table)
|
block_tables.append(block_table)
|
||||||
|
|
||||||
paged_kv_indices.extend(block_table)
|
# TODO(sang): This is a hack to make sliding window work with
|
||||||
paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
|
# paged attn. We can remove it if we make paged attn kernel
|
||||||
last_page_len = seq_data.get_len() % self.block_size
|
# to properly handle sliding window attn.
|
||||||
if last_page_len == 0:
|
if (self.sliding_window is not None and not is_prompt):
|
||||||
last_page_len = self.block_size
|
seq_len = min(seq_len, self.sliding_window)
|
||||||
paged_kv_last_page_len.append(last_page_len)
|
context_len = seq_len - 1
|
||||||
|
|
||||||
|
seq_lens.append(seq_len)
|
||||||
|
context_lens.append(context_len)
|
||||||
|
query_len = seq_len - context_len
|
||||||
|
query_lens.append(query_len)
|
||||||
|
input_tokens.extend(tokens)
|
||||||
|
input_positions.extend(list(range(context_len, seq_len)))
|
||||||
|
lora_id = seq_group_metadata.lora_int_id
|
||||||
|
|
||||||
|
if is_prompt:
|
||||||
|
assert len(seq_ids) == 1
|
||||||
|
num_prefills += 1
|
||||||
|
num_prefill_tokens += len(tokens)
|
||||||
|
decode_only = False
|
||||||
|
prefill_seq_lens.append(seq_len)
|
||||||
|
else:
|
||||||
|
assert query_len == 1, (
|
||||||
|
"seq_len: {}, context_len: {}, query_len: {}".format(
|
||||||
|
seq_len, context_len, query_len))
|
||||||
|
num_decode_tokens += query_len
|
||||||
|
decode_seq_lens.append(seq_len)
|
||||||
|
|
||||||
|
if lora_id > 0:
|
||||||
|
lora_requests.add(seq_group_metadata.lora_request)
|
||||||
|
|
||||||
|
lora_index_mapping += [lora_id] * (seq_len - context_len)
|
||||||
|
lora_prompt_mapping.extend(
|
||||||
|
[lora_id] *
|
||||||
|
(seq_len -
|
||||||
|
context_len if seq_group_metadata.sampling_params
|
||||||
|
and seq_group_metadata.sampling_params.prompt_logprobs
|
||||||
|
else 1))
|
||||||
|
|
||||||
|
if seq_group_metadata.multi_modal_data:
|
||||||
|
multi_modal_input_list.append(
|
||||||
|
seq_group_metadata.multi_modal_data.data)
|
||||||
|
|
||||||
|
if _is_block_tables_empty(seq_group_metadata.block_tables):
|
||||||
|
# During memory profiling, the block tables are not
|
||||||
|
# initialized yet. In this case, we just use a dummy
|
||||||
|
# slot mapping.
|
||||||
|
# In embeddings, the block tables are {seq_id: None}.
|
||||||
|
slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Compute the slot mapping.
|
||||||
|
block_table = seq_group_metadata.block_tables[seq_id]
|
||||||
|
|
||||||
|
# Mask the [0, start_idx) tokens of the prompt with
|
||||||
|
# _PAD_SLOT_ID, where start_idx is max(0, seq_len -
|
||||||
|
# sliding_window). For example, if the prompt len is 10,
|
||||||
|
# sliding window is 8, and block size is 4, the first two
|
||||||
|
# tokens are masked and the slot mapping will be
|
||||||
|
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
|
||||||
|
start_idx = 0
|
||||||
|
if self.sliding_window is not None:
|
||||||
|
if is_prompt:
|
||||||
|
assert context_len == 0, (
|
||||||
|
"Prefix caching is currently not supported with "
|
||||||
|
"sliding window attention")
|
||||||
|
# It is an optimization. When it is decoding, it is always
|
||||||
|
# 0. When prefill, we use it to not write slots to kv cache
|
||||||
|
# to save memory.
|
||||||
|
start_idx = max(0, query_len - self.sliding_window)
|
||||||
|
|
||||||
|
for i in range(context_len, seq_len):
|
||||||
|
if i < start_idx:
|
||||||
|
slot_mapping.append(_PAD_SLOT_ID)
|
||||||
|
continue
|
||||||
|
|
||||||
|
block_number = block_table[i // self.block_size]
|
||||||
|
block_offset = i % self.block_size
|
||||||
|
slot = block_number * self.block_size + block_offset
|
||||||
|
slot_mapping.append(slot)
|
||||||
|
|
||||||
# vLLM uses cuda graph only for decoding requests.
|
|
||||||
# See `capture_model` API for more details.
|
|
||||||
# For decoding requests, batch_size == input_tokens.
|
|
||||||
batch_size = len(input_tokens)
|
batch_size = len(input_tokens)
|
||||||
max_seq_len = max(seq_lens)
|
max_query_len = max(query_lens)
|
||||||
use_captured_graph = (not self.model_config.enforce_eager
|
max_prefill_seq_len = max(prefill_seq_lens, default=0)
|
||||||
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
|
max_decode_seq_len = max(decode_seq_lens, default=0)
|
||||||
and max_seq_len <= self.max_seq_len_to_capture)
|
|
||||||
|
# If cuda graph can be used, pad tensors accordingly.
|
||||||
|
# See `capture_model` API for more details.
|
||||||
|
# vLLM uses cuda graph only for decoding requests.
|
||||||
|
use_captured_graph = (
|
||||||
|
decode_only and not self.model_config.enforce_eager
|
||||||
|
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
|
||||||
|
and max_decode_seq_len <= self.max_seq_len_to_capture)
|
||||||
if use_captured_graph:
|
if use_captured_graph:
|
||||||
graph_batch_size = _get_graph_batch_size(batch_size)
|
graph_batch_size = _get_graph_batch_size(batch_size)
|
||||||
assert graph_batch_size >= batch_size
|
assert graph_batch_size >= batch_size
|
||||||
@ -519,18 +431,9 @@ class ModelRunner:
|
|||||||
block_tables.append([])
|
block_tables.append([])
|
||||||
lora_index_mapping.append(0)
|
lora_index_mapping.append(0)
|
||||||
batch_size = graph_batch_size
|
batch_size = graph_batch_size
|
||||||
|
num_decode_tokens = batch_size
|
||||||
seq_lens_tensor = torch.tensor(seq_lens,
|
|
||||||
dtype=torch.int,
|
|
||||||
device=self.device)
|
|
||||||
|
|
||||||
if use_captured_graph:
|
if use_captured_graph:
|
||||||
# When using cuda-graph all these tensors should be
|
|
||||||
# padded.
|
|
||||||
assert seq_lens_tensor.shape[0] == len(input_tokens)
|
|
||||||
assert seq_lens_tensor.shape[0] == len(input_positions)
|
|
||||||
assert seq_lens_tensor.shape[0] == len(slot_mapping)
|
|
||||||
|
|
||||||
# The shape of graph_block_tables is
|
# The shape of graph_block_tables is
|
||||||
# [max batch size, max context len // block size].
|
# [max batch size, max context len // block size].
|
||||||
input_block_tables = self.graph_block_tables[:batch_size]
|
input_block_tables = self.graph_block_tables[:batch_size]
|
||||||
@ -548,6 +451,57 @@ class ModelRunner:
|
|||||||
dtype=torch.int,
|
dtype=torch.int,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
|
||||||
|
|
||||||
|
context_lens_tensor = torch.tensor(context_lens,
|
||||||
|
dtype=torch.int,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
if multi_modal_input_list:
|
||||||
|
assert self.vision_language_config, (
|
||||||
|
"Multi-modal inputs are only supported by "
|
||||||
|
"vision language models.")
|
||||||
|
multi_modal_input = torch.cat(multi_modal_input_list,
|
||||||
|
dim=0).to(self.device)
|
||||||
|
else:
|
||||||
|
multi_modal_input = None
|
||||||
|
|
||||||
|
seq_lens_tensor = torch.tensor(seq_lens,
|
||||||
|
dtype=torch.int,
|
||||||
|
device=self.device)
|
||||||
|
query_lens_tensor = torch.tensor(query_lens,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.device)
|
||||||
|
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
seq_lens_tensor = torch.tensor(seq_lens,
|
||||||
|
dtype=torch.int,
|
||||||
|
device=self.device)
|
||||||
|
seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
torch.cumsum(query_lens_tensor,
|
||||||
|
dim=0,
|
||||||
|
dtype=query_start_loc.dtype,
|
||||||
|
out=query_start_loc[1:])
|
||||||
|
|
||||||
|
torch.cumsum(seq_lens_tensor,
|
||||||
|
dim=0,
|
||||||
|
dtype=seq_start_loc.dtype,
|
||||||
|
out=seq_start_loc[1:])
|
||||||
|
|
||||||
|
input_tokens_tensor = torch.tensor(input_tokens,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.device)
|
||||||
|
input_positions_tensor = torch.tensor(input_positions,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.device)
|
||||||
|
slot_mapping_tensor = torch.tensor(slot_mapping,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
if self.attn_backend.get_name() == "flashinfer":
|
if self.attn_backend.get_name() == "flashinfer":
|
||||||
if not hasattr(self, "flashinfer_workspace_buffer"):
|
if not hasattr(self, "flashinfer_workspace_buffer"):
|
||||||
@ -555,53 +509,75 @@ class ModelRunner:
|
|||||||
# Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
|
# Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
|
||||||
self.flashinfer_workspace_buffer = torch.empty(
|
self.flashinfer_workspace_buffer = torch.empty(
|
||||||
16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
|
16 * 1024 * 1024, dtype=torch.uint8, device=self.device)
|
||||||
paged_kv_indptr = torch.tensor(paged_kv_indptr,
|
paged_kv_indptr_tensor = torch.tensor(paged_kv_indptr,
|
||||||
dtype=torch.int,
|
|
||||||
device=self.device)
|
|
||||||
paged_kv_indices = torch.tensor(paged_kv_indices,
|
|
||||||
dtype=torch.int,
|
|
||||||
device=self.device)
|
|
||||||
paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len,
|
|
||||||
dtype=torch.int,
|
dtype=torch.int,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
paged_kv_indices_tensor = torch.tensor(paged_kv_indices,
|
||||||
|
dtype=torch.int,
|
||||||
|
device=self.device)
|
||||||
|
paged_kv_last_page_len_tensor = torch.tensor(
|
||||||
|
paged_kv_last_page_len, dtype=torch.int, device=self.device)
|
||||||
kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
|
kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype,
|
||||||
self.model_config.dtype)
|
self.model_config.dtype)
|
||||||
|
|
||||||
attn_metadata = self.attn_backend.make_metadata(
|
attn_metadata = self.attn_backend.make_metadata(
|
||||||
is_prompt=False,
|
num_prefills=num_prefills,
|
||||||
|
slot_mapping=slot_mapping_tensor,
|
||||||
|
num_prefill_tokens=num_prefill_tokens,
|
||||||
|
num_decode_tokens=num_decode_tokens,
|
||||||
use_cuda_graph=False,
|
use_cuda_graph=False,
|
||||||
|
max_prefill_seq_len=max_prefill_seq_len,
|
||||||
|
block_tables=block_tables,
|
||||||
workspace_buffer=self.flashinfer_workspace_buffer,
|
workspace_buffer=self.flashinfer_workspace_buffer,
|
||||||
paged_kv_indptr=paged_kv_indptr,
|
paged_kv_indptr=paged_kv_indptr_tensor,
|
||||||
paged_kv_indices=paged_kv_indices,
|
paged_kv_indices=paged_kv_indices_tensor,
|
||||||
paged_kv_last_page_len=paged_kv_last_page_len,
|
paged_kv_last_page_len=paged_kv_last_page_len_tensor,
|
||||||
num_qo_heads=self.model_config.get_num_attention_heads(
|
num_qo_heads=self.model_config.get_num_attention_heads(
|
||||||
self.parallel_config),
|
self.parallel_config),
|
||||||
num_kv_heads=self.model_config.get_num_kv_heads(
|
num_kv_heads=self.model_config.get_num_kv_heads(
|
||||||
self.parallel_config),
|
self.parallel_config),
|
||||||
head_dim=self.model_config.get_head_size(),
|
head_dim=self.model_config.get_head_size(),
|
||||||
page_size=self.block_size,
|
page_size=16,
|
||||||
|
seq_start_loc=seq_start_loc,
|
||||||
data_type=kv_cache_dtype)
|
data_type=kv_cache_dtype)
|
||||||
else:
|
else:
|
||||||
attn_metadata = self.attn_backend.make_metadata(
|
attn_metadata = self.attn_backend.make_metadata(
|
||||||
is_prompt=False,
|
num_prefills=num_prefills,
|
||||||
seq_lens=None,
|
slot_mapping=slot_mapping_tensor,
|
||||||
|
num_prefill_tokens=num_prefill_tokens,
|
||||||
|
num_decode_tokens=num_decode_tokens,
|
||||||
|
seq_lens=seq_lens,
|
||||||
seq_lens_tensor=seq_lens_tensor,
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
max_query_len=None,
|
max_query_len=max_query_len,
|
||||||
max_seq_len=max_seq_len,
|
max_prefill_seq_len=max_prefill_seq_len,
|
||||||
subquery_start_loc=None,
|
max_decode_seq_len=max_decode_seq_len,
|
||||||
seq_start_loc=None,
|
query_start_loc=query_start_loc,
|
||||||
context_lens_tensor=None,
|
seq_start_loc=seq_start_loc,
|
||||||
|
context_lens_tensor=context_lens_tensor,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
use_cuda_graph=use_captured_graph,
|
use_cuda_graph=use_captured_graph,
|
||||||
)
|
)
|
||||||
return PrepareDecodeMetadata(
|
|
||||||
input_tokens=input_tokens,
|
if self.lora_config:
|
||||||
input_positions=input_positions,
|
lora_mapping = LoRAMapping(
|
||||||
|
lora_index_mapping,
|
||||||
|
lora_prompt_mapping,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
lora_mapping = None
|
||||||
|
|
||||||
|
return ModelInput(
|
||||||
|
input_tokens=input_tokens_tensor,
|
||||||
|
input_positions=input_positions_tensor,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
lora_index_mapping=lora_index_mapping,
|
seq_lens=seq_lens,
|
||||||
lora_prompt_mapping=lora_prompt_mapping,
|
query_lens=query_lens,
|
||||||
|
lora_mapping=lora_mapping,
|
||||||
lora_requests=lora_requests,
|
lora_requests=lora_requests,
|
||||||
slot_mapping=slot_mapping,
|
multi_modal_input=multi_modal_input,
|
||||||
|
slot_mapping=slot_mapping_tensor,
|
||||||
|
num_prefill_tokens=num_prefill_tokens,
|
||||||
|
num_decode_tokens=num_decode_tokens,
|
||||||
|
num_prefills=num_prefills,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_input_tensors(
|
def prepare_input_tensors(
|
||||||
@ -610,85 +586,25 @@ class ModelRunner:
|
|||||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
|
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
|
||||||
Set[LoRARequest], LoRAMapping, torch.Tensor]:
|
Set[LoRARequest], LoRAMapping, torch.Tensor]:
|
||||||
if self.is_driver_worker:
|
if self.is_driver_worker:
|
||||||
prefill_reqs = []
|
|
||||||
decode_reqs = []
|
|
||||||
for seq_group_meta in seq_group_metadata_list:
|
|
||||||
if seq_group_meta.is_prompt:
|
|
||||||
prefill_reqs.append(seq_group_meta)
|
|
||||||
else:
|
|
||||||
decode_reqs.append(seq_group_meta)
|
|
||||||
|
|
||||||
# Prepare input tensors.
|
# Prepare input tensors.
|
||||||
(
|
(
|
||||||
input_tokens,
|
input_tokens,
|
||||||
input_positions,
|
input_positions,
|
||||||
prefill_attn_metadata,
|
attn_metadata,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
query_lens,
|
query_lens,
|
||||||
lora_index_mapping,
|
lora_mapping,
|
||||||
lora_prompt_mapping,
|
|
||||||
lora_requests,
|
lora_requests,
|
||||||
multi_modal_input,
|
multi_modal_input,
|
||||||
slot_mapping,
|
slot_mapping,
|
||||||
) = self._prepare_prompt(prefill_reqs)
|
num_prefill_tokens,
|
||||||
(
|
num_decode_tokens,
|
||||||
decode_input_tokens,
|
num_prefills,
|
||||||
decode_input_positions,
|
) = self._prepare_model_input(seq_group_metadata_list)
|
||||||
decode_attn_metadata,
|
|
||||||
decode_lora_index_mapping,
|
|
||||||
decode_lora_prompt_mapping,
|
|
||||||
decode_lora_requests,
|
|
||||||
decode_slot_mapping,
|
|
||||||
) = self._prepare_decode(decode_reqs)
|
|
||||||
sampling_metadata = SamplingMetadata.prepare(
|
sampling_metadata = SamplingMetadata.prepare(
|
||||||
seq_group_metadata_list, seq_lens, query_lens, self.device,
|
seq_group_metadata_list, seq_lens, query_lens, self.device,
|
||||||
self.pin_memory)
|
self.pin_memory)
|
||||||
|
|
||||||
if not self.scheduler_config.chunked_prefill_enabled:
|
|
||||||
assert (len(prefill_reqs) and len(decode_reqs)) == 0
|
|
||||||
|
|
||||||
num_prefills = len(seq_lens)
|
|
||||||
num_prefill_tokens = len(input_tokens)
|
|
||||||
num_decode_tokens = len(decode_input_tokens)
|
|
||||||
|
|
||||||
# Coalesce tensors. Note that attn_metadata is currently not
|
|
||||||
# coalesced for simplicity.
|
|
||||||
input_tokens.extend(decode_input_tokens)
|
|
||||||
input_positions.extend(decode_input_positions)
|
|
||||||
slot_mapping.extend(decode_slot_mapping)
|
|
||||||
lora_index_mapping.extend(decode_lora_index_mapping)
|
|
||||||
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
|
|
||||||
lora_requests.update(decode_lora_requests)
|
|
||||||
|
|
||||||
input_tokens = torch.tensor(input_tokens,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=self.device)
|
|
||||||
input_positions = torch.tensor(input_positions,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=self.device)
|
|
||||||
slot_mapping = torch.tensor(slot_mapping,
|
|
||||||
dtype=torch.long,
|
|
||||||
device=self.device)
|
|
||||||
|
|
||||||
if self.lora_config:
|
|
||||||
lora_mapping = LoRAMapping(
|
|
||||||
lora_index_mapping,
|
|
||||||
lora_prompt_mapping,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
lora_mapping = None
|
|
||||||
|
|
||||||
# Broadcast the metadata.
|
|
||||||
# If batch contains both prefill and decode, it sends 2 broadcasts.
|
|
||||||
# If it only contains 1 type, it triggers a single broadcast.
|
|
||||||
if (prefill_attn_metadata is not None
|
|
||||||
and decode_attn_metadata is not None):
|
|
||||||
batch_type = BatchType.MIXED
|
|
||||||
elif prefill_attn_metadata is not None:
|
|
||||||
batch_type = BatchType.PREFILL
|
|
||||||
else:
|
|
||||||
batch_type = BatchType.DECODE
|
|
||||||
|
|
||||||
metadata_dict = {
|
metadata_dict = {
|
||||||
"input_tokens": input_tokens,
|
"input_tokens": input_tokens,
|
||||||
"input_positions": input_positions,
|
"input_positions": input_positions,
|
||||||
@ -701,46 +617,24 @@ class ModelRunner:
|
|||||||
"num_decode_tokens": num_decode_tokens,
|
"num_decode_tokens": num_decode_tokens,
|
||||||
"slot_mapping": slot_mapping,
|
"slot_mapping": slot_mapping,
|
||||||
"num_prefills": num_prefills,
|
"num_prefills": num_prefills,
|
||||||
"batch_type": batch_type,
|
|
||||||
}
|
}
|
||||||
if prefill_attn_metadata is not None:
|
if attn_metadata:
|
||||||
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
|
metadata_dict.update(attn_metadata.asdict_zerocopy())
|
||||||
else:
|
|
||||||
assert decode_attn_metadata is not None
|
|
||||||
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
|
|
||||||
broadcast_tensor_dict(metadata_dict, src=0)
|
broadcast_tensor_dict(metadata_dict, src=0)
|
||||||
|
|
||||||
# Broadcast decode attn metadata for mixed batch type.
|
|
||||||
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
|
|
||||||
# We can potentially reduce the overhead by coalescing tensors.
|
|
||||||
if batch_type == BatchType.MIXED:
|
|
||||||
assert decode_attn_metadata is not None
|
|
||||||
metadata_dict = decode_attn_metadata.asdict_zerocopy()
|
|
||||||
broadcast_tensor_dict(metadata_dict, src=0)
|
|
||||||
else:
|
else:
|
||||||
metadata_dict = broadcast_tensor_dict(src=0)
|
metadata_dict = broadcast_tensor_dict(src=0)
|
||||||
input_tokens = metadata_dict.pop("input_tokens")
|
input_tokens = metadata_dict.pop("input_tokens")
|
||||||
input_positions = metadata_dict.pop("input_positions")
|
input_positions = metadata_dict.pop("input_positions")
|
||||||
slot_mapping = metadata_dict.pop("slot_mapping")
|
|
||||||
num_prefills = metadata_dict.pop("num_prefills")
|
|
||||||
selected_token_indices = metadata_dict.pop(
|
selected_token_indices = metadata_dict.pop(
|
||||||
"selected_token_indices")
|
"selected_token_indices")
|
||||||
lora_mapping = metadata_dict.pop("lora_mapping")
|
lora_mapping = metadata_dict.pop("lora_mapping")
|
||||||
lora_requests = metadata_dict.pop("lora_requests")
|
lora_requests = metadata_dict.pop("lora_requests")
|
||||||
multi_modal_input = metadata_dict.pop("multi_modal_input")
|
multi_modal_input = metadata_dict.pop("multi_modal_input")
|
||||||
num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
|
if metadata_dict:
|
||||||
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
|
attn_metadata = self.attn_backend.make_metadata(
|
||||||
batch_type = metadata_dict.pop("batch_type")
|
|
||||||
|
|
||||||
# Create an attention metadata.
|
|
||||||
prefill_attn_metadata = None
|
|
||||||
decode_attn_metadata = None
|
|
||||||
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
|
|
||||||
prefill_attn_metadata = self.attn_backend.make_metadata(
|
|
||||||
**metadata_dict)
|
**metadata_dict)
|
||||||
else:
|
else:
|
||||||
decode_attn_metadata = self.attn_backend.make_metadata(
|
attn_metadata = None
|
||||||
**metadata_dict)
|
|
||||||
sampling_metadata = SamplingMetadata(
|
sampling_metadata = SamplingMetadata(
|
||||||
seq_groups=None,
|
seq_groups=None,
|
||||||
selected_token_indices=selected_token_indices,
|
selected_token_indices=selected_token_indices,
|
||||||
@ -748,22 +642,6 @@ class ModelRunner:
|
|||||||
num_prompts=0,
|
num_prompts=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# if it is a mixed batch, decode attn_metadata is broadcasted
|
|
||||||
# separately.
|
|
||||||
if batch_type == BatchType.MIXED:
|
|
||||||
metadata_dict = broadcast_tensor_dict(src=0)
|
|
||||||
decode_attn_metadata = self.attn_backend.make_metadata(
|
|
||||||
**metadata_dict)
|
|
||||||
|
|
||||||
attn_metadata = AttentionMetadata(
|
|
||||||
num_prefills=num_prefills,
|
|
||||||
slot_mapping=slot_mapping,
|
|
||||||
num_prefill_tokens=num_prefill_tokens,
|
|
||||||
num_decode_tokens=num_decode_tokens,
|
|
||||||
prefill_metadata=prefill_attn_metadata,
|
|
||||||
decode_metadata=decode_attn_metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
return (input_tokens, input_positions, attn_metadata,
|
return (input_tokens, input_positions, attn_metadata,
|
||||||
sampling_metadata, lora_requests, lora_mapping,
|
sampling_metadata, lora_requests, lora_mapping,
|
||||||
multi_modal_input)
|
multi_modal_input)
|
||||||
@ -954,25 +832,21 @@ class ModelRunner:
|
|||||||
# memory usage of CUDA graph.
|
# memory usage of CUDA graph.
|
||||||
for batch_size in reversed(batch_size_capture_list):
|
for batch_size in reversed(batch_size_capture_list):
|
||||||
# Create dummy attn_metadata.
|
# Create dummy attn_metadata.
|
||||||
decode_metadata = self.attn_backend.make_metadata(
|
attn_metadata = self.attn_backend.make_metadata(
|
||||||
is_prompt=False,
|
|
||||||
seq_lens=None,
|
|
||||||
seq_lens_tensor=seq_lens[:batch_size],
|
|
||||||
max_query_len=None,
|
|
||||||
max_seq_len=self.max_seq_len_to_capture,
|
|
||||||
subquery_start_loc=None,
|
|
||||||
seq_start_loc=None,
|
|
||||||
context_lens_tensor=None,
|
|
||||||
block_tables=block_tables[:batch_size],
|
|
||||||
use_cuda_graph=True,
|
|
||||||
)
|
|
||||||
attn_metadata = AttentionMetadata(
|
|
||||||
num_prefills=0,
|
num_prefills=0,
|
||||||
num_prefill_tokens=0,
|
num_prefill_tokens=0,
|
||||||
num_decode_tokens=batch_size,
|
num_decode_tokens=batch_size,
|
||||||
slot_mapping=slot_mapping[:batch_size],
|
slot_mapping=slot_mapping[:batch_size],
|
||||||
prefill_metadata=None,
|
seq_lens=None,
|
||||||
decode_metadata=decode_metadata,
|
seq_lens_tensor=seq_lens[:batch_size],
|
||||||
|
max_query_len=None,
|
||||||
|
max_prefill_seq_len=0,
|
||||||
|
max_decode_seq_len=self.max_seq_len_to_capture,
|
||||||
|
query_start_loc=None,
|
||||||
|
seq_start_loc=None,
|
||||||
|
context_lens_tensor=None,
|
||||||
|
block_tables=block_tables[:batch_size],
|
||||||
|
use_cuda_graph=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.lora_config:
|
if self.lora_config:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user