[Core][5/N] Fully working chunked prefill e2e (#3884)

SangBin Cho 2024-04-11 09:56:48 +09:00 committed by GitHub
parent 63e7176f26
commit 67b4221a61
26 changed files with 927 additions and 315 deletions

View File

@ -29,6 +29,8 @@ steps:
- pytest -v -s test_pynccl.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
- label: Engine Test
command: pytest -v -s engine tokenization test_sequence.py test_config.py

View File

@ -177,8 +177,7 @@ if __name__ == '__main__':
help='block size of key/value cache')
parser.add_argument(
'--enable-chunked-prefill',
type=bool,
default=False,
action='store_true',
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
parser.add_argument(

View File

@ -74,25 +74,31 @@ def run_vllm(
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
gpu_memory_utilization: float = 0.9,
download_dir: Optional[str] = None,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir)
llm = LLM(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
)
# Add the requests to the engine.
for prompt, _, output_len in requests:
@ -213,15 +219,15 @@ def main(args: argparse.Namespace):
args.output_len)
if args.backend == "vllm":
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching,
args.gpu_memory_utilization, args.download_dir)
elapsed_time = run_vllm(
requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.gpu_memory_utilization,
args.download_dir)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@ -335,6 +341,14 @@ if __name__ == "__main__":
"--enable-prefix-caching",
action='store_true',
help="enable automatic prefix caching for vLLM backend.")
parser.add_argument("--enable-chunked-prefill",
action='store_true',
help="enable chunked prefill for vLLM backend.")
parser.add_argument('--max-num-batched-tokens',
type=int,
default=None,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--download-dir',
type=str,
default=None,

View File

@ -0,0 +1,70 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.
It tests chunked prefill. Chunked prefill can be enabled by
enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens,
prefill requests are chunked.
Run `pytest tests/models/test_chunked_prefill.py`.
"""
import pytest
MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False, True])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset the distributed env properly. Use a value > 1 only when testing locally.
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
enforce_eager: bool,
tensor_parallel_size: int,
) -> None:
if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16
and not enforce_eager):
pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} "
"for high TP to save testing time.")
max_num_seqs = min(chunked_prefill_token_size, 256)
enable_chunked_prefill = False
max_num_batched_tokens = None
if chunked_prefill_token_size != -1:
enable_chunked_prefill = True
max_num_batched_tokens = chunked_prefill_token_size
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(
model,
dtype=dtype,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
print(vllm_outputs[0])
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")

View File

@ -104,10 +104,10 @@ def test_chunk():
# One chunked prefill, and one decoding.
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
# The first one is decoding.
assert seq_group_meta[0].token_chunk_size == 1
# The first one is prefill. Scheduler guarantees ordering.
assert seq_group_meta[0].token_chunk_size == 56
# The second one is a chunked prefill.
assert seq_group_meta[1].token_chunk_size == 56
assert seq_group_meta[1].token_chunk_size == 1
assert out.num_prefill_groups == 1
assert out.num_batched_tokens == 57
@ -157,12 +157,12 @@ def test_complex():
# Decoding & chunked prefill & first chunk of 3rd request is scheduled.
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert len(get_sequence_groups(out)) == 3
# The first one is decoding.
assert seq_group_meta[0].token_chunk_size == 1
# The second one is a chunked prefill.
# The first one is the first chunked prefill.
assert seq_group_meta[0].token_chunk_size == 7
# The second one is the second new chunked prefill.
assert seq_group_meta[1].token_chunk_size == 56
# The third one is also chunked.
assert seq_group_meta[2].token_chunk_size == 7
# The last one is decode.
assert seq_group_meta[2].token_chunk_size == 1
# Two of them are in chunked prefill.
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 64

View File

@ -33,11 +33,16 @@ def test_models(
dtype: str,
max_tokens: int,
) -> None:
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2)
vllm_model = vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model

View File

@ -0,0 +1,66 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
vLLM will allocate all the available memory, so we need to run the tests one
by one. The solution is to pass arguments (model name) by environment
variables.
Run:
```sh
TEST_DIST_MODEL=facebook/opt-125m pytest \
test_chunked_prefill_distributed.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest \
test_chunked_prefill_distributed.py
```
"""
import os
import pytest
import torch
MODELS = [
os.environ["TEST_DIST_MODEL"],
]
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
) -> None:
# Add a chunked prefill config.
max_num_seqs = min(chunked_prefill_token_size, 256)
assert chunked_prefill_token_size != -1
enable_chunked_prefill = True
max_num_batched_tokens = chunked_prefill_token_size
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
max_num_seqs=max_num_seqs,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")

View File

@ -141,7 +141,7 @@ def server(zephyr_lora_files):
"--max-cpu-loras",
"2",
"--max-num-seqs",
"128"
"128",
])
ray.get(server_runner.ready.remote())
yield server_runner

View File

@ -12,7 +12,7 @@ MODELS = [
"gpt2",
"bigcode/tiny_starcoder_py",
"EleutherAI/pythia-70m",
"bigscience/bloom-560m",
"bigscience/bloom-560m", # Testing alibi slopes.
"microsoft/phi-2",
"stabilityai/stablelm-3b-4e1t",
# "allenai/OLMo-1B", # Broken

View File

@ -1,14 +1,18 @@
import pytest
import torch
from vllm.config import ModelConfig
from vllm.config import ModelConfig, SchedulerConfig
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
model_runner = ModelRunner(None, None, None, None, None)
scheduler_config = SchedulerConfig(100000,
100000,
100000,
enable_chunked_prefill=False)
model_runner = ModelRunner(None, None, scheduler_config, None, None)
model_runner.set_block_size(16)
prompt_lens = []
@ -36,8 +40,10 @@ def test_prepare_prompt(batch_size):
prompt_len - 1)
selected_token_start_idx += prompt_len
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_, _) = (model_runner._prepare_prompt(seq_group_metadata_list))
_, _,
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
assert len(slot_mapping) == len(input_tokens)
# Verify input metadata is correct for prompts.
device = model_runner.device
@ -45,8 +51,6 @@ def test_prepare_prompt(batch_size):
assert torch.allclose(attn_metadata.prompt_lens_tensor,
torch.tensor(prompt_lens, device=device))
assert attn_metadata.prompt_lens == prompt_lens
assert attn_metadata.num_prompt_tokens == sum(prompt_lens)
assert attn_metadata.num_generation_tokens == 0
assert attn_metadata.max_prompt_len == max(prompt_lens)
# Test subquery start locs.
@ -83,23 +87,22 @@ def test_prepare_prompt(batch_size):
assert torch.allclose(attn_metadata.block_tables, expected)
# Cuda graph should not be used for prefill.
assert attn_metadata.use_cuda_graph is False
assert attn_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (sum(prompt_lens), )
assert input_positions.shape == (sum(prompt_lens), )
assert len(input_tokens) == sum(prompt_lens)
assert len(input_positions) == sum(prompt_lens)
torch.testing.assert_close(input_tokens, input_positions)
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
assert input_tokens.shape == (sum(prompt_lens), )
assert input_positions.shape == (sum(prompt_lens), )
assert len(input_tokens) == sum(prompt_lens)
assert len(input_positions) == sum(prompt_lens)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
torch.testing.assert_close(input_tokens, input_positions)
assert input_tokens == input_positions
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
@ -122,7 +125,12 @@ def test_prepare_decode_cuda_graph(batch_size):
revision=None,
enforce_eager=False,
)
model_runner = ModelRunner(model_config, None, None, None, None)
scheduler_config = SchedulerConfig(100000,
100000,
100000,
enable_chunked_prefill=False)
model_runner = ModelRunner(model_config, None, scheduler_config, None,
None)
model_runner.set_block_size(16)
prompt_lens = []
@ -143,16 +151,15 @@ def test_prepare_decode_cuda_graph(batch_size):
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
input_tokens, input_positions, attn_metadata, _, _, _ = (
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
assert len(slot_mapping) == len(input_tokens)
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for decodes.
device = model_runner.device
assert attn_metadata.is_prompt is False
assert attn_metadata.prompt_lens is None
assert attn_metadata.num_prompt_tokens == 0
assert attn_metadata.num_generation_tokens == expected_bs
assert attn_metadata.max_prompt_len is None
assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None
@ -170,11 +177,10 @@ def test_prepare_decode_cuda_graph(batch_size):
model_runner.get_max_block_per_batch())
# Cuda graph should be used for decoding.
assert attn_metadata.use_cuda_graph is True
assert attn_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (expected_bs, )
assert input_positions.shape == (expected_bs, )
torch.testing.assert_close(input_tokens, input_positions)
assert len(input_tokens) == expected_bs
assert len(input_positions) == expected_bs
assert input_tokens == input_positions
# Verify Sampling
expected_selected_token_indices = []
@ -190,3 +196,148 @@ def test_prepare_decode_cuda_graph(batch_size):
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
def test_empty_seq_group():
"""Verify prepare prompt and decode returns empty output."""
model_config = ModelConfig(
"facebook/opt-125m",
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
enforce_eager=False,
)
model_runner = ModelRunner(model_config, None, None, None, None)
model_runner.set_block_size(16)
seq_group_metadata_list = []
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_, _,
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0
assert len(return_prompt_lens) == 0
@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
def get_world_size(group=None):
return 1
def mock_get_process_group_ranks(group=None):
return [0]
monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size)
monkeypatch.setattr(torch.distributed, "get_process_group_ranks",
mock_get_process_group_ranks)
model_config = ModelConfig(
"facebook/opt-125m",
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
enforce_eager=enforce_eager,
)
scheduler_config = SchedulerConfig(100000,
100000,
100000,
enable_chunked_prefill=True)
model_runner = ModelRunner(model_config,
None,
scheduler_config,
None,
None,
is_driver_worker=True)
model_runner.set_block_size(16)
# Add prefill requests.
prompt_lens = []
seq_group_metadata_list = []
prefill_metadata_list = []
decode_metadata_list = []
block_tables = {0: [1]}
prefill_batch_size = batch_size // 2
decode_batch_size = batch_size - prefill_batch_size
for i in range(prefill_batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = SequenceData(list(range(prompt_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
)
assert seq_group_metadata.token_chunk_size == seq_data.get_len()
seq_group_metadata_list.append(seq_group_metadata)
prefill_metadata_list.append(seq_group_metadata)
# Add decode requests
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(prompt_len))
seq_data = SequenceData(prompt_toks)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables={0: [1]},
)
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
decode_metadata_list.append(seq_group_metadata)
(input_tokens, input_positions, attn_metadata, _, _, _,
_) = model_runner.prepare_input_tensors(seq_group_metadata_list)
prefill_meta_actual = attn_metadata.prefill_metadata
decode_meta_actual = attn_metadata.decode_metadata
assert len(attn_metadata.slot_mapping) == len(input_tokens)
assert len(input_positions) == len(input_tokens)
assert attn_metadata.kv_cache_dtype == "auto"
assert attn_metadata.num_prefills == prefill_batch_size
if enforce_eager:
assert attn_metadata.num_decode_tokens == decode_batch_size
else:
assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
decode_batch_size)
assert attn_metadata.num_prefill_tokens == sum(prompt_lens)
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
prefill_meta = model_runner._prepare_prompt(
prefill_metadata_list).attn_metadata
decode_meta = model_runner._prepare_decode(
decode_metadata_list).attn_metadata
for attr_expected, attr_actual in zip(vars(prefill_meta),
vars(prefill_meta_actual)):
assert attr_expected[1] == attr_actual[1]
for attr_expected, attr_actual in zip(vars(decode_meta),
vars(decode_meta_actual)):
assert attr_expected[1] == attr_actual[1]

View File

@ -1,5 +1,6 @@
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
@ -8,4 +9,5 @@ __all__ = [
"AttentionMetadata",
"Attention",
"get_attn_backend",
"AttentionMetadataPerStage",
]

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
import torch
@ -47,7 +47,8 @@ class AttentionBackend(ABC):
@dataclass
class AttentionMetadata:
class AttentionMetadataPerStage:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
def asdict_zerocopy(self) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
@ -59,6 +60,41 @@ class AttentionMetadata:
}
T = TypeVar("T", bound=AttentionMetadataPerStage)
@dataclass
class AttentionMetadata(Generic[T]):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# The attention metadata for prefill requests in a batch.
# None if there's no prefill requests in a batch.
prefill_metadata: Optional[T]
# The attention metadata for decode requests in a batch.
# None if there's no decode requests in a batch.
decode_metadata: Optional[T]
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# The kv cache's data type.
kv_cache_dtype: str
def __post_init__(self):
if self.num_prefill_tokens > 0:
assert self.num_prefills > 0
assert self.prefill_metadata is not None
if self.num_decode_tokens > 0:
assert self.decode_metadata is not None
class AttentionImpl(ABC):
@abstractmethod
@ -80,7 +116,7 @@ class AttentionImpl(ABC):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
kv_scale: float,
) -> torch.Tensor:
raise NotImplementedError
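The `slot_mapping` comment above is easiest to follow with the arithmetic written out. A small sketch, assuming the usual paged-KV layout where `slot = block_number * block_size + block_offset`:

```python
# Reproduce the worked example from the slot_mapping docstring:
# with block size 16, slot 35 lands in block 2 at offset 3, slot 2 in
# block 0 at offset 2, and slot 17 in block 1 at offset 1.
def slot_to_block_offset(slot: int, block_size: int) -> tuple[int, int]:
    return divmod(slot, block_size)

block_size = 16
for slot in (35, 2, 17):
    block, offset = slot_to_block_offset(slot, block_size)
    print(f"slot {slot} -> block {block}, offset {offset}")
```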

View File

@ -11,7 +11,8 @@ import torch
from flash_attn import flash_attn_varlen_func
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
@ -53,7 +54,8 @@ class FlashAttentionBackend(AttentionBackend):
@dataclass
class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
class FlashAttentionMetadata(AttentionMetadataPerStage,
PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
@ -68,10 +70,6 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
@ -107,18 +105,27 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
class FlashAttentionImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
@ -155,7 +162,7 @@ class FlashAttentionImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata,
attn_metadata: AttentionMetadata[FlashAttentionMetadata],
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@ -188,52 +195,70 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata.kv_cache_dtype,
kv_scale)
if attn_metadata.is_prompt:
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
output = flash_attn_varlen_func(
out = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=attn_metadata.seq_start_loc,
cu_seqlens_k=attn_metadata.seq_start_loc,
max_seqlen_q=attn_metadata.max_prompt_len,
max_seqlen_k=attn_metadata.max_prompt_len,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prompt_len,
max_seqlen_k=prefill_meta.max_prompt_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# to be addressed separately.
output = PagedAttention.forward_prefix(
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.subquery_start_loc,
attn_metadata.prompt_lens_tensor,
attn_metadata.context_lens,
attn_metadata.max_subquery_len,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.prompt_lens_tensor,
prefill_meta.context_lens,
prefill_meta.max_subquery_len,
self.alibi_slopes,
)
else:
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output = PagedAttention.forward_decode(
query,
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.context_lens,
attn_metadata.max_context_len,
decode_meta.block_tables,
decode_meta.context_lens,
decode_meta.max_context_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
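The flattened prefill-then-decode layout described in the docstring above is what the new slicing in `forward` relies on. A toy sketch of that split on a dummy query tensor (token and head counts are arbitrary):

```python
import torch

num_prefill_tokens, num_decode_tokens = 5, 3
num_heads, head_size = 2, 4

# Prefill tokens come first, decode tokens follow, with no padding in
# between (cuda graph is disabled for chunked prefill).
query = torch.randn(num_prefill_tokens + num_decode_tokens,
                    num_heads, head_size)

# The same slicing the backend performs before dispatching to the
# prefill kernel and the paged-attention decode kernel.
prefill_query = query[:num_prefill_tokens]
decode_query = query[num_prefill_tokens:]

assert prefill_query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
```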

View File

@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
@ -51,7 +52,8 @@ class ROCmFlashAttentionBackend(AttentionBackend):
@dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
@ -66,10 +68,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
@ -117,6 +115,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
@ -181,7 +188,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata,
attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
kv_scale: float = 1.0,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@ -218,9 +225,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_scale,
)
if attn_metadata.is_prompt:
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
@ -230,63 +253,69 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key = self.repeat_kv(key, self.num_queries_per_kv)
value = self.repeat_kv(value, self.num_queries_per_kv)
if self.use_naive_attn:
output = self.attn_fuc(
out = self.attn_fuc(
query,
key,
value,
attn_metadata.prompt_lens,
prefill_meta.prompt_lens,
self.scale,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else:
output, _ = self.attn_func(
out, _ = self.attn_func(
query,
key,
value,
None,
attn_metadata.seq_start_loc,
attn_metadata.seq_start_loc,
attn_metadata.max_prompt_len,
attn_metadata.max_prompt_len,
prefill_meta.seq_start_loc,
prefill_meta.seq_start_loc,
prefill_meta.max_prompt_len,
prefill_meta.max_prompt_len,
True,
self.scale,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else:
output = self.attn_func(
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=attn_metadata.seq_start_loc,
cu_seqlens_k=attn_metadata.seq_start_loc,
max_seqlen_q=attn_metadata.max_prompt_len,
max_seqlen_k=attn_metadata.max_prompt_len,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prompt_len,
max_seqlen_k=prefill_meta.max_prompt_len,
softmax_scale=self.scale,
causal=True,
)
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
output = PagedAttention.forward_prefix(
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.subquery_start_loc,
attn_metadata.prompt_lens_tensor,
attn_metadata.context_lens,
attn_metadata.max_subquery_len,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.prompt_lens_tensor,
prefill_meta.context_lens,
prefill_meta.max_subquery_len,
self.alibi_slopes,
)
else:
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output = PagedAttention.forward_decode(
query,
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.context_lens,
attn_metadata.max_context_len,
decode_meta.block_tables,
decode_meta.context_lens,
decode_meta.max_context_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,

View File

@ -7,7 +7,8 @@ import torch
from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
@ -49,17 +50,14 @@ class TorchSDPABackend(AttentionBackend):
@dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
"""Metadata for TorchSDPABackend.
"""
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
slot_mapping: torch.Tensor
prompt_lens: Optional[List[int]]
prompt_lens_tensor: Optional[torch.Tensor]
num_prompt_tokens: int
num_generation_tokens: int
max_subquery_len: Optional[int] = None
max_prompt_len: Optional[int] = None
@ -113,7 +111,7 @@ class TorchSDPABackendImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata,
attn_metadata: AttentionMetadata[TorchSDPAMetadata],
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@ -142,36 +140,51 @@ class TorchSDPABackendImpl(AttentionImpl):
attn_metadata.kv_cache_dtype,
kv_scale)
if attn_metadata.is_prompt:
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
if (kv_cache is None or prefill_meta.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=1)
if attn_metadata.attn_bias is None:
if prefill_meta.attn_bias is None:
if self.alibi_slopes is not None:
att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype,
attn_metadata.prompt_lens) # type: ignore
prefill_meta.prompt_lens) # type: ignore
elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias(
attn_metadata.prompt_lens, self.sliding_window,
prefill_meta.prompt_lens, self.sliding_window,
query.dtype) # type: ignore
else:
att_masks = [None] * len(attn_metadata.prompt_lens)
attn_metadata.attn_bias = att_masks
att_masks = [None] * len(prefill_meta.prompt_lens)
prefill_meta.attn_bias = att_masks
query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2)
start = 0
output = torch.empty(
(num_tokens, self.num_heads, self.head_size),
dtype=query.dtype)
for prompt_len, mask in zip(attn_metadata.prompt_lens,
attn_metadata.attn_bias):
out = torch.empty((num_tokens, self.num_heads, self.head_size),
dtype=query.dtype)
for prompt_len, mask in zip(prefill_meta.prompt_lens,
prefill_meta.attn_bias):
end = start + prompt_len
sub_out = scaled_dot_product_attention(
query[:, start:end, :],
@ -181,28 +194,32 @@ class TorchSDPABackendImpl(AttentionImpl):
dropout_p=0.0,
is_causal=not self.need_mask,
scale=self.scale).movedim(query.dim() - 2, 0)
output[start:end, :, :] = sub_out
out[start:end, :, :] = sub_out
start = end
assert out.shape == output[:num_prefill_tokens].shape
output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
raise RuntimeError(
"Torch SDPA backend doesn't support prefix decoding.")
else:
if decode_meta := attn_metadata.decode_metadata:
# Decoding run.
output = PagedAttention.forward_decode(
query,
out = PagedAttention.forward_decode(
decode_query,
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.context_lens,
attn_metadata.max_context_len,
decode_meta.block_tables,
decode_meta.context_lens,
decode_meta.max_context_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
kv_scale,
)
assert out.shape == output[num_prefill_tokens:].shape
output[num_prefill_tokens:] = out
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)

View File

@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
LowerTriangularMaskWithTensorBias)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
from vllm.logger import init_logger
@ -54,7 +55,7 @@ class XFormersBackend(AttentionBackend):
@dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
"""Metadata for XFormersbackend.
NOTE: Any python object stored here is not updated when it is
@ -65,19 +66,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts.
is_prompt: bool
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------|
@ -123,18 +115,27 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
class XFormersImpl(AttentionImpl):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens --------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->|
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def __init__(
@ -170,7 +171,7 @@ class XFormersImpl(AttentionImpl):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: XFormersMetadata,
attn_metadata: AttentionMetadata[XFormersMetadata],
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@ -202,59 +203,61 @@ class XFormersImpl(AttentionImpl):
attn_metadata.kv_cache_dtype,
kv_scale)
if attn_metadata.is_prompt:
num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# normal attention.
# block tables are empty if the prompt does not have a cached
# prefix.
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
output = self._run_memory_efficient_xformers_forward(
query, key, value, attn_metadata)
out = self._run_memory_efficient_xformers_forward(
query, key, value, prefill_meta)
assert out.shape == output[:num_prefill_tokens].shape
output[:num_prefill_tokens] = out
else:
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# to be addressed separately.
output = PagedAttention.forward_prefix(
out = PagedAttention.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.subquery_start_loc,
attn_metadata.prompt_lens_tensor,
attn_metadata.context_lens,
attn_metadata.max_subquery_len,
prefill_meta.block_tables,
prefill_meta.subquery_start_loc,
prefill_meta.prompt_lens_tensor,
prefill_meta.context_lens,
prefill_meta.max_subquery_len,
self.alibi_slopes,
)
else:
# Decoding run.
output = PagedAttention.forward_decode(
query,
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
if decode_meta := attn_metadata.decode_metadata:
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache,
value_cache,
attn_metadata.block_tables,
attn_metadata.context_lens,
attn_metadata.max_context_len,
decode_meta.block_tables,
decode_meta.context_lens,
decode_meta.max_context_len,
attn_metadata.kv_cache_dtype,
self.num_kv_heads,
self.scale,
@ -275,13 +278,30 @@ class XFormersImpl(AttentionImpl):
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
See https://facebookresearch.github.io/xformers/components/ops.html
for API spec.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
output: shape = [num_prefill_tokens, num_heads, head_size]
query: shape = [num_prefill_tokens, num_heads, head_size]
key: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
"""
original_query = query
if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K].
# Note that the output also has the same shape (which differs from
# the spec in the doc).
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv, query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# Set attention bias if not provided. This typically happens at
# the very first attention layer of every iteration.
# FIXME(woosuk): This is a hack.
@ -302,6 +322,7 @@ class XFormersImpl(AttentionImpl):
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
# Add the batch dimension.
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
@ -312,14 +333,13 @@ class XFormersImpl(AttentionImpl):
attn_bias=attn_metadata.attn_bias[0],
p=0.0,
scale=self.scale)
return out.view_as(query)
return out.view_as(original_query)
# Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
output = torch.empty_like(query)
output = torch.empty_like(original_query)
start = 0
for i, prompt_len in enumerate(attn_metadata.prompt_lens):
end = start + prompt_len
@ -331,7 +351,7 @@ class XFormersImpl(AttentionImpl):
p=0.0,
scale=self.scale)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.squeeze(0))
output[start:end].copy_(out.view_as(original_query[start:end]))
start += prompt_len
return output
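The GQA/MQA reshaping that moved into `_run_memory_efficient_xformers_forward` is easier to see with concrete shapes. A standalone sketch of the grouping on toy tensors (head counts are arbitrary):

```python
import torch

# Toy GQA layout: 8 query heads sharing 2 KV heads.
num_tokens, num_heads, num_kv_heads, head_size = 6, 8, 2, 16
num_queries_per_kv = num_heads // num_kv_heads

query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)

# Group query heads by the KV head they share: [T, G, H/G, K].
query = query.view(num_tokens, num_kv_heads, num_queries_per_kv, head_size)
# Broadcast each KV head across its query group without copying data.
key = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                num_queries_per_kv, head_size)
value = value[:, :, None, :].expand(num_tokens, num_kv_heads,
                                    num_queries_per_kv, head_size)

assert query.shape == key.shape == value.shape
```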

View File

@ -4,7 +4,8 @@ from typing import List, Optional
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.abstract import (AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.selector import get_attn_backend
@ -41,7 +42,7 @@ class Attention(nn.Module):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
kv_scale: float = 1.0,
) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,

View File

@ -13,11 +13,6 @@ _PARTITION_SIZE = 512
@dataclass
class PagedAttentionMetadata:
"""Metadata for PagedAttention."""
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
@ -31,7 +26,6 @@ class PagedAttentionMetadata:
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
kv_cache_dtype: str
class PagedAttention:

View File

@ -565,9 +565,16 @@ class SchedulerConfig:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
else:
# If max_model_len is too short, use 2048 as the default value for
# higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048)
if enable_chunked_prefill:
# For chunked prefill, choose the well-tuned batch size.
self.max_num_batched_tokens = 768
else:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048)
if enable_chunked_prefill:
logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.use_v2_block_manager = use_v2_block_manager
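A condensed sketch of the default-selection logic added above (a standalone function, not the actual `SchedulerConfig`): an explicit value always wins, chunked prefill falls back to the tuned 768-token budget, and otherwise the budget is at least 2048 tokens.

```python
from typing import Optional

def default_max_num_batched_tokens(max_num_batched_tokens: Optional[int],
                                   max_model_len: int,
                                   enable_chunked_prefill: bool) -> int:
    if max_num_batched_tokens is not None:
        return max_num_batched_tokens
    if enable_chunked_prefill:
        # Well-tuned default batch size for chunked prefill.
        return 768
    # If max_model_len is too short, use 2048 for higher throughput.
    return max(max_model_len, 2048)

assert default_max_num_batched_tokens(None, 4096, True) == 768
assert default_max_num_batched_tokens(None, 512, False) == 2048
assert default_max_num_batched_tokens(1024, 4096, True) == 1024
```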

View File

@ -140,7 +140,11 @@ class SchedulerOutputs:
@property
def lora_requests(self) -> Set[LoRARequest]:
return {g.seq_group.lora_request for g in self.scheduled_seq_groups}
return {
g.seq_group.lora_request
for g in self.scheduled_seq_groups
if g.seq_group.lora_request is not None
}
@dataclass
@ -826,13 +830,12 @@ class Scheduler:
# Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out)
return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups +
running_scheduled.decode_seq_groups +
running_scheduled.prefill_seq_groups +
swapped_in.decode_seq_groups +
swapped_in.prefill_seq_groups),
swapped_in.prefill_seq_groups +
running_scheduled.decode_seq_groups +
swapped_in.decode_seq_groups),
num_prefill_groups=(len(prefills.seq_groups) +
len(swapped_in.prefill_seq_groups) +
len(running_scheduled.prefill_seq_groups)),
@ -907,7 +910,7 @@ class Scheduler:
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
is_prompt = i < scheduler_outputs.num_prefill_groups
is_prompt = seq_group.is_prefill()
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=is_prompt,

View File

@ -173,10 +173,18 @@ def broadcast_tensor_dict(
torch.distributed.broadcast_object_list([metadata_list],
src=src,
group=group)
async_handles = []
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
tensor = tensor_dict[key]
torch.distributed.broadcast(tensor, src=src, group=group)
async_handles.append(
torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True))
for async_handle in async_handles:
async_handle.wait()
else:
recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list,
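The change above launches every tensor broadcast with `async_op=True` and only then waits on the returned work handles, so the transfers can overlap instead of running one after another. A minimal sketch of that pattern, assuming an already-initialized process group:

```python
from typing import List

import torch
import torch.distributed as dist

def broadcast_all(tensors: List[torch.Tensor], src: int = 0) -> None:
    # Kick off every broadcast first; each call returns a Work handle.
    handles = [
        dist.broadcast(tensor, src=src, async_op=True) for tensor in tensors
    ]
    # Then wait on all of them, letting the transfers proceed concurrently.
    for handle in handles:
        handle.wait()
```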

View File

@ -386,9 +386,8 @@ class EngineArgs:
'prompt latency) before scheduling next prompt.')
parser.add_argument(
'--enable-chunked-prefill',
type=bool,
default=False,
help='If True, the prefill requests can be chunked based on the '
action='store_true',
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
parser.add_argument(
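The switch from `type=bool` to `action='store_true'` (here and in the benchmark scripts) fixes a classic argparse pitfall: `type=bool` applies `bool()` to the raw string, so even an explicit `False` parses as truthy. A small self-contained demonstration:

```python
import argparse

# With type=bool, bool("False") is True, so the flag cannot be disabled.
broken = argparse.ArgumentParser()
broken.add_argument('--enable-chunked-prefill', type=bool, default=False)
args = broken.parse_args(['--enable-chunked-prefill', 'False'])
assert args.enable_chunked_prefill is True

# With action='store_true', the flag is False unless it is present.
fixed = argparse.ArgumentParser()
fixed.add_argument('--enable-chunked-prefill', action='store_true')
assert fixed.parse_args([]).enable_chunked_prefill is False
assert fixed.parse_args(['--enable-chunked-prefill']).enable_chunked_prefill
```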

View File

@ -633,7 +633,10 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
self._process_sequence_group_outputs(seq_group, outputs)
# If uncomputed tokens > 0, it means prefill is chunked.
# We don't need to process outputs in that case.
if seq_group.get_num_uncomputed_tokens() == 0:
self._process_sequence_group_outputs(seq_group, outputs)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()

View File

@ -267,12 +267,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1
indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
embedding_len = self.indices_len[3]
indices = self.embeddings_indices[1][:embedding_len].view_as(x)
full_lora_a_embeddings = F.embedding(
x + indices,
self.lora_a_stacked_2d,
)
indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
indices = self.embeddings_indices[0][:embedding_len].view_as(x)
full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask))

View File

@ -500,7 +500,8 @@ class SequenceGroup:
def get_num_uncomputed_tokens(self) -> int:
num_uncomputed_tokens = 0
for seq in self.get_seqs():
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
if not seq.is_finished():
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
return num_uncomputed_tokens
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:

View File

@ -1,12 +1,14 @@
import contextlib
import time
from typing import Dict, List, Optional, Set, Tuple
from enum import IntEnum
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
import numpy as np
import torch
import torch.nn as nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend)
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
@ -37,6 +39,66 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
]
class PreparePromptMetadata(NamedTuple):
input_tokens: List[int]
input_positions: List[int]
attn_metadata: Optional[AttentionMetadataPerStage]
prompt_lens: List[int]
subquery_lens: List[int]
lora_index_mapping: List[int]
lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest]
multi_modal_input: Optional[torch.Tensor]
slot_mapping: List[int]
@classmethod
def empty(cls):
return PreparePromptMetadata(
input_tokens=[],
input_positions=[],
attn_metadata=None,
prompt_lens=[],
subquery_lens=[],
lora_index_mapping=[],
lora_prompt_mapping=[],
lora_requests=set(),
multi_modal_input=None,
slot_mapping=[],
)
class PrepareDecodeMetadata(NamedTuple):
input_tokens: List[int]
input_positions: List[int]
attn_metadata: Optional[AttentionMetadata]
lora_index_mapping: List[int]
lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest]
slot_mapping: List[int]
@classmethod
def empty(cls):
return PrepareDecodeMetadata(
input_tokens=[],
input_positions=[],
attn_metadata=None,
lora_index_mapping=[],
lora_prompt_mapping=[],
lora_requests=set(),
slot_mapping=[],
)
# How batches are constructed.
class BatchType(IntEnum):
# Every request in the batch is a prefill.
PREFILL = 0
# Every request in the batch is a decode.
DECODE = 1
# The batch is a mixture of prefill and decode requests.
MIXED = 2
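The `BatchType` values above follow directly from whether a scheduler step contains prefill requests, decode requests, or both. A self-contained sketch (the helper is hypothetical, not part of this commit) of that classification:

```python
from enum import IntEnum

class BatchType(IntEnum):
    PREFILL = 0
    DECODE = 1
    MIXED = 2

# Hypothetical helper mapping request counts to the batch type. Mixed
# batches only occur when chunked prefill lets prefills and decodes be
# scheduled in the same step.
def classify_batch(num_prefills: int, num_decode_tokens: int) -> BatchType:
    if num_prefills > 0 and num_decode_tokens > 0:
        return BatchType.MIXED
    return BatchType.PREFILL if num_prefills > 0 else BatchType.DECODE

assert classify_batch(4, 0) is BatchType.PREFILL
assert classify_batch(0, 8) is BatchType.DECODE
assert classify_batch(4, 8) is BatchType.MIXED
```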
class ModelRunner:
def __init__(
@ -152,10 +214,7 @@ class ModelRunner:
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
List[int], List[int], List[int], Set[LoRARequest],
torch.Tensor]:
assert len(seq_group_metadata_list) > 0
) -> PreparePromptMetadata:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
@ -169,6 +228,9 @@ class ModelRunner:
prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = []
if len(seq_group_metadata_list) == 0:
return PreparePromptMetadata.empty()
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
@ -178,7 +240,8 @@ class ModelRunner:
computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
and computed_block_nums is not None):
and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")
@ -190,13 +253,8 @@ class ModelRunner:
# it contains output tokens.
prefill_end = min(seq_data.get_len(),
computed_len + token_chunk_size)
# TODO(sang): Rename it after chunked prefill is introduced.
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
prompt_len = len(prompt_tokens)
# Right now, the prefill_end is always same as the length of
# sequence. However, once chunked prefill is introduced, this
# assumption can be changed.
assert prefill_end == seq_data.get_len()
prompt_len = prefill_end
prompt_lens.append(prompt_len)
# NOTE: This only works for oooooooxxx style attention.
@ -206,6 +264,14 @@ class ModelRunner:
computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:]
prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
if seq_group_metadata.block_tables is not None:
# Prefill has chunked before.
block_table = seq_group_metadata.block_tables[seq_id]
prefix_block_tables.append(block_table)
else:
# The first prefill.
prefix_block_tables.append([])
else:
prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this
@ -267,20 +333,8 @@ class ModelRunner:
max_subquery_len = max(subquery_lens)
max_prompt_len = max(prompt_lens)
num_prompt_tokens = len(input_tokens)
assert max_subquery_len > 0
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
lora_index_mapping = lora_index_mapping
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
@ -332,11 +386,8 @@ class ModelRunner:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
slot_mapping=slot_mapping,
prompt_lens=prompt_lens,
prompt_lens_tensor=prompt_lens_tensor,
num_prompt_tokens=num_prompt_tokens,
num_generation_tokens=0,
max_subquery_len=max_subquery_len,
max_context_len=None,
max_prompt_len=max_prompt_len,
@ -345,18 +396,25 @@ class ModelRunner:
context_lens=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests, multi_modal_input)
return PreparePromptMetadata(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
prompt_lens=prompt_lens,
subquery_lens=subquery_lens,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
multi_modal_input=multi_modal_input,
slot_mapping=slot_mapping,
)
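    # Worked example (illustrative, assuming a 1024-token prompt scheduled
    # with token_chunk_size=512): the first call prefills tokens [0, 512)
    # with an empty prefix block table; the second call prefills [512, 1024)
    # and reuses the block table allocated for the first chunk, so
    # computed_len=512, prefill_end=1024 and prompt_tokens=ids[512:1024].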
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0
) -> PrepareDecodeMetadata:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
@ -366,6 +424,9 @@ class ModelRunner:
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()
if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty()
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1
@ -424,15 +485,6 @@ class ModelRunner:
lora_index_mapping.append(0)
batch_size = graph_batch_size
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
@ -440,9 +492,9 @@ class ModelRunner:
if use_captured_graph:
            # When using CUDA graph, all these tensors should be padded.
assert context_lens.shape[0] == input_tokens.shape[0]
assert context_lens.shape[0] == input_positions.shape[0]
assert context_lens.shape[0] == slot_mapping.shape[0]
assert context_lens.shape[0] == len(input_tokens)
assert context_lens.shape[0] == len(input_positions)
assert context_lens.shape[0] == len(slot_mapping)
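            # These are still plain Python lists here; they are turned into
            # tensors only after prefill and decode inputs are coalesced in
            # prepare_input_tensors.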
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
@ -464,11 +516,8 @@ class ModelRunner:
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping,
prompt_lens=None,
prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=len(input_tokens),
max_subquery_len=None,
max_context_len=max_context_len,
max_prompt_len=None,
@ -477,10 +526,16 @@ class ModelRunner:
context_lens=context_lens,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata,
lora_index_mapping, lora_prompt_mapping, lora_requests)
return PrepareDecodeMetadata(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
slot_mapping=slot_mapping,
)
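    # Note that _prepare_prompt and _prepare_decode now return raw lists
    # instead of device tensors; prepare_input_tensors concatenates the
    # prefill and decode halves first and creates each tensor exactly once.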
def _prepare_sample(
self,
@ -586,26 +641,66 @@ class ModelRunner:
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[int], LoRAMapping, torch.Tensor]:
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
prefill_reqs = []
decode_reqs = []
for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
prefill_reqs.append(seq_group_meta)
else:
decode_reqs.append(seq_group_meta)
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests, multi_modal_input
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions, attn_metadata,
lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
subquery_lens = None
multi_modal_input = None
(
input_tokens,
input_positions,
prefill_attn_metadata,
prompt_lens,
subquery_lens,
lora_index_mapping,
lora_prompt_mapping,
lora_requests,
multi_modal_input,
slot_mapping,
) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
decode_input_positions,
decode_attn_metadata,
decode_lora_index_mapping,
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens)
if not self.scheduler_config.chunked_prefill_enabled:
                assert len(prefill_reqs) == 0 or len(decode_reqs) == 0
num_prefills = len(prompt_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)
# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens.extend(decode_input_tokens)
input_positions.extend(decode_input_positions)
slot_mapping.extend(decode_slot_mapping)
lora_index_mapping.extend(decode_lora_index_mapping)
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
lora_requests.update(decode_lora_requests)
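            # Ordering matters: prefill entries come first, then decode
            # entries, so num_prefill_tokens recorded above marks the split
            # point inside every coalesced tensor.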
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
if self.lora_config:
lora_mapping = LoRAMapping(
lora_index_mapping,
@ -615,6 +710,16 @@ class ModelRunner:
lora_mapping = None
# Broadcast the metadata.
            # If the batch contains both prefill and decode, it sends 2 broadcasts.
            # If it contains only one type, it triggers a single broadcast.
if (prefill_attn_metadata is not None
and decode_attn_metadata is not None):
batch_type = BatchType.MIXED
elif prefill_attn_metadata is not None:
batch_type = BatchType.PREFILL
else:
batch_type = BatchType.DECODE
metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
@ -623,19 +728,49 @@ class ModelRunner:
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
"multi_modal_input": multi_modal_input,
"num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
"batch_type": batch_type,
}
metadata_dict.update(attn_metadata.asdict_zerocopy())
if prefill_attn_metadata is not None:
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
else:
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
# Broadcast decode attn metadata for mixed batch type.
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
            # We can potentially reduce the overhead by coalescing tensors.
if batch_type == BatchType.MIXED:
assert decode_attn_metadata is not None
metadata_dict = decode_attn_metadata.asdict_zerocopy()
broadcast_tensor_dict(metadata_dict, src=0)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions")
slot_mapping = metadata_dict.pop("slot_mapping")
num_prefills = metadata_dict.pop("num_prefills")
selected_token_indices = metadata_dict.pop(
"selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input")
attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
batch_type = metadata_dict.pop("batch_type")
            # Create the attention metadata.
prefill_attn_metadata = None
decode_attn_metadata = None
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
prefill_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
else:
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
@ -646,6 +781,23 @@ class ModelRunner:
perform_sampling=False,
)
            # If it is a mixed batch, the decode attn_metadata is broadcast
            # separately.
if batch_type == BatchType.MIXED:
metadata_dict = broadcast_tensor_dict(src=0)
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
attn_metadata = AttentionMetadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping,
multi_modal_input)
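    # Broadcast protocol sketch (illustrative): for a MIXED batch the driver
    # and every worker perform two matched broadcasts; single-type batches
    # need only the first one.
    #
    #   driver (src=0)                        worker
    #   broadcast(common fields + prefill     dict = broadcast_tensor_dict(src=0)
    #       attn fields, batch_type=MIXED)    prefill_attn_metadata = make_metadata(**dict)
    #   broadcast(decode attn fields)         dict = broadcast_tensor_dict(src=0)
    #                                         decode_attn_metadata = make_metadata(**dict)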
@ -663,8 +815,10 @@ class ModelRunner:
if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)
# Execute the model.
if attn_metadata.use_cuda_graph:
        # Currently, CUDA graphs are only supported for the decode phase.
prefill_meta = attn_metadata.prefill_metadata
decode_meta = attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
@ -842,13 +996,10 @@ class ModelRunner:
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
# Create dummy attn_metadata.
attn_metadata = self.attn_backend.make_metadata(
decode_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping[:batch_size],
prompt_lens=None,
prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=batch_size,
max_subquery_len=None,
max_context_len=self.max_context_len_to_capture,
max_prompt_len=None,
@ -857,6 +1008,14 @@ class ModelRunner:
context_lens=context_lens[:batch_size],
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)
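                # Wrap the decode-only metadata; prefill_metadata stays None
                # since CUDA graphs are captured for the decode path only.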
attn_metadata = AttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping[:batch_size],
prefill_metadata=None,
decode_metadata=decode_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)
@ -950,8 +1109,8 @@ class CUDAGraphRunner:
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping,
"context_lens": attn_metadata.context_lens,
"block_tables": attn_metadata.block_tables,
"context_lens": attn_metadata.decode_metadata.context_lens,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
self.output_buffers = {"hidden_states": hidden_states}
return
@ -972,10 +1131,10 @@ class CUDAGraphRunner:
self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True)
self.input_buffers["context_lens"].copy_(attn_metadata.context_lens,
non_blocking=True)
self.input_buffers["block_tables"].copy_(attn_metadata.block_tables,
non_blocking=True)
self.input_buffers["context_lens"].copy_(
attn_metadata.decode_metadata.context_lens, non_blocking=True)
self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
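        # The graph replays with fixed buffer addresses, so the new decode
        # inputs (including decode_metadata.context_lens and block_tables)
        # must be copied into the captured buffers before replay.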
# Run the graph.
self.graph.replay()