[Core][5/N] Fully working chunked prefill e2e (#3884)

SangBin Cho 2024-04-11 09:56:48 +09:00 committed by GitHub
parent 63e7176f26
commit 67b4221a61
26 changed files with 927 additions and 315 deletions

View File

@ -29,6 +29,8 @@ steps:
- pytest -v -s test_pynccl.py - pytest -v -s test_pynccl.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
- label: Engine Test - label: Engine Test
command: pytest -v -s engine tokenization test_sequence.py test_config.py command: pytest -v -s engine tokenization test_sequence.py test_config.py

View File

@ -177,8 +177,7 @@ if __name__ == '__main__':
help='block size of key/value cache') help='block size of key/value cache')
parser.add_argument( parser.add_argument(
'--enable-chunked-prefill', '--enable-chunked-prefill',
type=bool, action='store_true',
default=False,
help='If True, the prefill requests can be chunked based on the ' help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens') 'max_num_batched_tokens')
parser.add_argument( parser.add_argument(
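Switching from `type=bool` to `action='store_true'` matters because argparse's `type=bool` calls `bool()` on the raw string, so any non-empty value (including "False") parses as true. A minimal standalone sketch of the difference (the `--broken-flag` option is hypothetical, not part of the benchmark script):

```python
import argparse

parser = argparse.ArgumentParser()
# type=bool applies bool() to the raw string: bool("False") is True.
parser.add_argument('--broken-flag', type=bool, default=False)
# action='store_true' is the intended on/off switch.
parser.add_argument('--enable-chunked-prefill', action='store_true')

args = parser.parse_args(['--broken-flag', 'False', '--enable-chunked-prefill'])
assert args.broken_flag is True           # surprising: "False" still enables it
assert args.enable_chunked_prefill is True
```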

View File

@ -74,25 +74,31 @@ def run_vllm(
quantization_param_path: Optional[str], quantization_param_path: Optional[str],
device: str, device: str,
enable_prefix_caching: bool, enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
download_dir: Optional[str] = None, download_dir: Optional[str] = None,
) -> float: ) -> float:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM(model=model, llm = LLM(
tokenizer=tokenizer, model=model,
quantization=quantization, tokenizer=tokenizer,
tensor_parallel_size=tensor_parallel_size, quantization=quantization,
seed=seed, tensor_parallel_size=tensor_parallel_size,
trust_remote_code=trust_remote_code, seed=seed,
dtype=dtype, trust_remote_code=trust_remote_code,
max_model_len=max_model_len, dtype=dtype,
gpu_memory_utilization=gpu_memory_utilization, max_model_len=max_model_len,
enforce_eager=enforce_eager, gpu_memory_utilization=gpu_memory_utilization,
kv_cache_dtype=kv_cache_dtype, enforce_eager=enforce_eager,
quantization_param_path=quantization_param_path, kv_cache_dtype=kv_cache_dtype,
device=device, quantization_param_path=quantization_param_path,
enable_prefix_caching=enable_prefix_caching, device=device,
download_dir=download_dir) enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
)
# Add the requests to the engine. # Add the requests to the engine.
for prompt, _, output_len in requests: for prompt, _, output_len in requests:
@ -213,15 +219,15 @@ def main(args: argparse.Namespace):
args.output_len) args.output_len)
if args.backend == "vllm": if args.backend == "vllm":
elapsed_time = run_vllm(requests, args.model, args.tokenizer, elapsed_time = run_vllm(
args.quantization, args.tensor_parallel_size, requests, args.model, args.tokenizer, args.quantization,
args.seed, args.n, args.use_beam_search, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.trust_remote_code, args.dtype, args.max_model_len,
args.max_model_len, args.enforce_eager, args.enforce_eager, args.kv_cache_dtype,
args.kv_cache_dtype, args.quantization_param_path, args.device,
args.quantization_param_path, args.device, args.enable_prefix_caching, args.enable_chunked_prefill,
args.enable_prefix_caching, args.max_num_batched_tokens, args.gpu_memory_utilization,
args.gpu_memory_utilization, args.download_dir) args.download_dir)
elif args.backend == "hf": elif args.backend == "hf":
assert args.tensor_parallel_size == 1 assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n, elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@ -335,6 +341,14 @@ if __name__ == "__main__":
"--enable-prefix-caching", "--enable-prefix-caching",
action='store_true', action='store_true',
help="enable automatic prefix caching for vLLM backend.") help="enable automatic prefix caching for vLLM backend.")
parser.add_argument("--enable-chunked-prefill",
action='store_true',
help="enable chunked prefill for vLLM backend.")
parser.add_argument('--max-num-batched-tokens',
type=int,
default=None,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--download-dir', parser.add_argument('--download-dir',
type=str, type=str,
default=None, default=None,
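For reference, a minimal sketch of how the two new benchmark flags map onto the engine arguments that `run_vllm` now forwards (the model name and token budget below are illustrative):

```python
from vllm import LLM, SamplingParams

# Illustrative values: cap each scheduler step at 2048 batched tokens, so
# longer prefills are split across steps.
llm = LLM(
    model="facebook/opt-125m",
    enable_chunked_prefill=True,
    max_num_batched_tokens=2048,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```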

View File

@ -0,0 +1,70 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.
It tests chunked prefill. Chunked prefill is enabled by setting
enable_chunked_prefill=True. If a prompt's prefill size exceeds
max_num_batched_tokens, the prefill request is split into chunks.
Run `pytest tests/models/test_chunked_prefill.py`.
"""
import pytest
MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False, True])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset the distributed env properly. Only use a value > 1 when testing locally.
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
enforce_eager: bool,
tensor_parallel_size: int,
) -> None:
if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16
and not enforce_eager):
pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} "
"for high TP to save testing time.")
max_num_seqs = min(chunked_prefill_token_size, 256)
enable_chunked_prefill = False
max_num_batched_tokens = None
if chunked_prefill_token_size != -1:
enable_chunked_prefill = True
max_num_batched_tokens = chunked_prefill_token_size
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(
model,
dtype=dtype,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
print(vllm_outputs[0])
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")

View File

@ -104,10 +104,10 @@ def test_chunk():
# One chunked prefill, and one decoding. # One chunked prefill, and one decoding.
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running) assert set(get_sequence_groups(out)) == set(running)
# The first one is decoding. # The first one is prefill. Scheduler guarantees ordering.
assert seq_group_meta[0].token_chunk_size == 1 assert seq_group_meta[0].token_chunk_size == 56
# The second one is a chunked prefill. # The second one is a chunked prefill.
assert seq_group_meta[1].token_chunk_size == 56 assert seq_group_meta[1].token_chunk_size == 1
assert out.num_prefill_groups == 1 assert out.num_prefill_groups == 1
assert out.num_batched_tokens == 57 assert out.num_batched_tokens == 57
@ -157,12 +157,12 @@ def test_complex():
# Decoding & chunked prefill & first chunk of 3rd request is scheduled. # Decoding & chunked prefill & first chunk of 3rd request is scheduled.
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert len(get_sequence_groups(out)) == 3 assert len(get_sequence_groups(out)) == 3
# The first one is decoding. # The first one is the first chunked prefill.
assert seq_group_meta[0].token_chunk_size == 1 assert seq_group_meta[0].token_chunk_size == 7
# The second one is a chunked prefill. # The second one is the second new chunked prefill.
assert seq_group_meta[1].token_chunk_size == 56 assert seq_group_meta[1].token_chunk_size == 56
# The third one is also chunked. # The last one is decode.
assert seq_group_meta[2].token_chunk_size == 7 assert seq_group_meta[2].token_chunk_size == 1
# Two of them are in chunked prefill. # Two of them are in chunked prefill.
assert out.num_prefill_groups == 2 assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 64 assert out.num_batched_tokens == 64
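A compact restatement of the two batches asserted above, using the test's own numbers; with chunked prefill, the scheduler places prefill groups before decode groups in the same step:

```python
# test_chunk: one 56-token prefill chunk batched with one decode token.
chunk_batch = [56, 1]
assert sum(chunk_batch) == 57

# test_complex: a 7-token and a 56-token prefill chunk plus one decode token.
complex_batch = [7, 56, 1]
assert sum(complex_batch) == 64
```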

View File

@ -33,11 +33,16 @@ def test_models(
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
) -> None: ) -> None:
hf_model = hf_runner(model, dtype=dtype) hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model del hf_model
vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2) vllm_model = vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model del vllm_model

View File

@ -0,0 +1,66 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
vLLM will allocate all the available memory, so we need to run the tests one
by one. The solution is to pass arguments (the model name) via environment
variables.
Run:
```sh
TEST_DIST_MODEL=facebook/opt-125m pytest \
test_chunked_prefill_distributed.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest \
test_chunked_prefill_distributed.py
```
"""
import os
import pytest
import torch
MODELS = [
os.environ["TEST_DIST_MODEL"],
]
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
) -> None:
# Add a chunked prefill config.
max_num_seqs = min(chunked_prefill_token_size, 256)
assert chunked_prefill_token_size != -1
enable_chunked_prefill = True
max_num_batched_tokens = chunked_prefill_token_size
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
max_num_seqs=max_num_seqs,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")

View File

@ -141,7 +141,7 @@ def server(zephyr_lora_files):
"--max-cpu-loras", "--max-cpu-loras",
"2", "2",
"--max-num-seqs", "--max-num-seqs",
"128" "128",
]) ])
ray.get(server_runner.ready.remote()) ray.get(server_runner.ready.remote())
yield server_runner yield server_runner

View File

@ -12,7 +12,7 @@ MODELS = [
"gpt2", "gpt2",
"bigcode/tiny_starcoder_py", "bigcode/tiny_starcoder_py",
"EleutherAI/pythia-70m", "EleutherAI/pythia-70m",
"bigscience/bloom-560m", "bigscience/bloom-560m", # Testing alibi slopes.
"microsoft/phi-2", "microsoft/phi-2",
"stabilityai/stablelm-3b-4e1t", "stabilityai/stablelm-3b-4e1t",
# "allenai/OLMo-1B", # Broken # "allenai/OLMo-1B", # Broken

View File

@ -1,14 +1,18 @@
import pytest import pytest
import torch import torch
from vllm.config import ModelConfig from vllm.config import ModelConfig, SchedulerConfig
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
@pytest.mark.parametrize("batch_size", list(range(1, 257))) @pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size): def test_prepare_prompt(batch_size):
model_runner = ModelRunner(None, None, None, None, None) scheduler_config = SchedulerConfig(100000,
100000,
100000,
enable_chunked_prefill=False)
model_runner = ModelRunner(None, None, scheduler_config, None, None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
prompt_lens = [] prompt_lens = []
@ -36,8 +40,10 @@ def test_prepare_prompt(batch_size):
prompt_len - 1) prompt_len - 1)
selected_token_start_idx += prompt_len selected_token_start_idx += prompt_len
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_, _) = (model_runner._prepare_prompt(seq_group_metadata_list)) _, _,
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens assert return_prompt_lens == prompt_lens
assert len(slot_mapping) == len(input_tokens)
# Verify input metadata is correct for prompts. # Verify input metadata is correct for prompts.
device = model_runner.device device = model_runner.device
@ -45,8 +51,6 @@ def test_prepare_prompt(batch_size):
assert torch.allclose(attn_metadata.prompt_lens_tensor, assert torch.allclose(attn_metadata.prompt_lens_tensor,
torch.tensor(prompt_lens, device=device)) torch.tensor(prompt_lens, device=device))
assert attn_metadata.prompt_lens == prompt_lens assert attn_metadata.prompt_lens == prompt_lens
assert attn_metadata.num_prompt_tokens == sum(prompt_lens)
assert attn_metadata.num_generation_tokens == 0
assert attn_metadata.max_prompt_len == max(prompt_lens) assert attn_metadata.max_prompt_len == max(prompt_lens)
# Test subquery start locs. # Test subquery start locs.
@ -83,23 +87,22 @@ def test_prepare_prompt(batch_size):
assert torch.allclose(attn_metadata.block_tables, expected) assert torch.allclose(attn_metadata.block_tables, expected)
# Cuda graph should not be used for prefill. # Cuda graph should not be used for prefill.
assert attn_metadata.use_cuda_graph is False assert attn_metadata.use_cuda_graph is False
assert attn_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (sum(prompt_lens), ) assert len(input_tokens) == sum(prompt_lens)
assert input_positions.shape == (sum(prompt_lens), ) assert len(input_positions) == sum(prompt_lens)
torch.testing.assert_close(input_tokens, input_positions) torch.testing.assert_close(input_tokens, input_positions)
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens, prompt_lens,
subquery_lens=prompt_lens) subquery_lens=prompt_lens)
assert input_tokens.shape == (sum(prompt_lens), ) assert len(input_tokens) == sum(prompt_lens)
assert input_positions.shape == (sum(prompt_lens), ) assert len(input_positions) == sum(prompt_lens)
actual = sampling_metadata.selected_token_indices actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices, expected = torch.tensor(expected_selected_token_indices,
device=actual.device, device=actual.device,
dtype=actual.dtype) dtype=actual.dtype)
torch.testing.assert_close(actual, expected) torch.testing.assert_close(actual, expected)
torch.testing.assert_close(input_tokens, input_positions) assert input_tokens == input_positions
actual = sampling_metadata.selected_token_indices actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices, expected = torch.tensor(expected_selected_token_indices,
@ -122,7 +125,12 @@ def test_prepare_decode_cuda_graph(batch_size):
revision=None, revision=None,
enforce_eager=False, enforce_eager=False,
) )
model_runner = ModelRunner(model_config, None, None, None, None) scheduler_config = SchedulerConfig(100000,
100000,
100000,
enable_chunked_prefill=False)
model_runner = ModelRunner(model_config, None, scheduler_config, None,
None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
prompt_lens = [] prompt_lens = []
@ -143,16 +151,15 @@ def test_prepare_decode_cuda_graph(batch_size):
assert seq_group_metadata.token_chunk_size == 1 assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
input_tokens, input_positions, attn_metadata, _, _, _ = ( input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list)) model_runner._prepare_decode(seq_group_metadata_list))
assert len(slot_mapping) == len(input_tokens)
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for prompts. # Verify input metadata is correct for prompts.
device = model_runner.device device = model_runner.device
assert attn_metadata.is_prompt is False assert attn_metadata.is_prompt is False
assert attn_metadata.prompt_lens is None assert attn_metadata.prompt_lens is None
assert attn_metadata.num_prompt_tokens == 0
assert attn_metadata.num_generation_tokens == expected_bs
assert attn_metadata.max_prompt_len is None assert attn_metadata.max_prompt_len is None
assert attn_metadata.subquery_start_loc is None assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None assert attn_metadata.seq_start_loc is None
@ -170,11 +177,10 @@ def test_prepare_decode_cuda_graph(batch_size):
model_runner.get_max_block_per_batch()) model_runner.get_max_block_per_batch())
# Cuda graph should be used for decode-only batches. # Cuda graph should be used for decode-only batches.
assert attn_metadata.use_cuda_graph is True assert attn_metadata.use_cuda_graph is True
assert attn_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (expected_bs, ) assert len(input_tokens) == expected_bs
assert input_positions.shape == (expected_bs, ) assert len(input_positions) == expected_bs
torch.testing.assert_close(input_tokens, input_positions) assert input_tokens == input_positions
# Verify Sampling # Verify Sampling
expected_selected_token_indices = [] expected_selected_token_indices = []
@ -190,3 +196,148 @@ def test_prepare_decode_cuda_graph(batch_size):
device=actual.device, device=actual.device,
dtype=actual.dtype) dtype=actual.dtype)
torch.testing.assert_close(actual, expected) torch.testing.assert_close(actual, expected)
def test_empty_seq_group():
"""Verify prepare prompt and decode returns empty output."""
model_config = ModelConfig(
"facebook/opt-125m",
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
enforce_eager=False,
)
model_runner = ModelRunner(model_config, None, None, None, None)
model_runner.set_block_size(16)
seq_group_metadata_list = []
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_, _,
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0
assert len(return_prompt_lens) == 0
@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
def get_world_size(group=None):
return 1
def mock_get_process_group_ranks(group=None):
return [0]
monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size)
monkeypatch.setattr(torch.distributed, "get_process_group_ranks",
mock_get_process_group_ranks)
model_config = ModelConfig(
"facebook/opt-125m",
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
enforce_eager=enforce_eager,
)
scheduler_config = SchedulerConfig(100000,
100000,
100000,
enable_chunked_prefill=True)
model_runner = ModelRunner(model_config,
None,
scheduler_config,
None,
None,
is_driver_worker=True)
model_runner.set_block_size(16)
# Add prefill requests.
prompt_lens = []
seq_group_metadata_list = []
prefill_metadata_list = []
decode_metadata_list = []
block_tables = {0: [1]}
prefill_batch_size = batch_size // 2
decode_batch_size = batch_size - prefill_batch_size
for i in range(prefill_batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = SequenceData(list(range(prompt_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
)
assert seq_group_metadata.token_chunk_size == seq_data.get_len()
seq_group_metadata_list.append(seq_group_metadata)
prefill_metadata_list.append(seq_group_metadata)
# Add decode requests
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(prompt_len))
seq_data = SequenceData(prompt_toks)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables={0: [1]},
)
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
decode_metadata_list.append(seq_group_metadata)
(input_tokens, input_positions, attn_metadata, _, _, _,
_) = model_runner.prepare_input_tensors(seq_group_metadata_list)
prefill_meta_actual = attn_metadata.prefill_metadata
decode_meta_actual = attn_metadata.decode_metadata
assert len(attn_metadata.slot_mapping) == len(input_tokens)
assert len(input_positions) == len(input_tokens)
assert attn_metadata.kv_cache_dtype == "auto"
assert attn_metadata.num_prefills == prefill_batch_size
if enforce_eager:
assert attn_metadata.num_decode_tokens == decode_batch_size
else:
assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
decode_batch_size)
assert attn_metadata.num_prefill_tokens == sum(prompt_lens)
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
prefill_meta = model_runner._prepare_prompt(
prefill_metadata_list).attn_metadata
decode_meta = model_runner._prepare_decode(
decode_metadata_list).attn_metadata
for attr_expected, attr_actual in zip(vars(prefill_meta),
vars(prefill_meta_actual)):
assert attr_expected[1] == attr_actual[1]
for attr_expected, attr_actual in zip(vars(decode_meta),
vars(decode_meta_actual)):
assert attr_expected[1] == attr_actual[1]
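The `enforce_eager` branch above reflects that, with CUDA graphs enabled, decode tokens are padded up to a pre-captured graph batch size. A hedged sketch of the idea; the rounding policy below is an assumption for illustration, not the exact `_get_graph_batch_size` implementation:

```python
def round_up_to_graph_batch_size(batch_size: int) -> int:
    # Assumed policy: pad small decode batches to one of a few fixed sizes
    # so a pre-captured CUDA graph can be replayed without re-capture.
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return (batch_size + 7) // 8 * 8

decode_batch_size = 5
padded = round_up_to_graph_batch_size(decode_batch_size)
assert padded >= decode_batch_size  # the extra slots are padding tokens
```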

View File

@ -1,5 +1,6 @@
from vllm.attention.backends.abstract import (AttentionBackend, from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata) AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
@ -8,4 +9,5 @@ __all__ = [
"AttentionMetadata", "AttentionMetadata",
"Attention", "Attention",
"get_attn_backend", "get_attn_backend",
"AttentionMetadataPerStage",
] ]

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
import torch import torch
@ -47,7 +47,8 @@ class AttentionBackend(ABC):
@dataclass @dataclass
class AttentionMetadata: class AttentionMetadataPerStage:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
def asdict_zerocopy(self) -> Dict[str, Any]: def asdict_zerocopy(self) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying.""" """Similar to dataclasses.asdict, but avoids deepcopying."""
@ -59,6 +60,41 @@ class AttentionMetadata:
} }
T = TypeVar("T", bound=AttentionMetadataPerStage)
@dataclass
class AttentionMetadata(Generic[T]):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# The attention metadata for prefill requests in a batch.
# None if there's no prefill requests in a batch.
prefill_metadata: Optional[T]
# The attention metadata for decode requests in a batch.
# None if there's no decode requests in a batch.
decode_metadata: Optional[T]
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# The kv cache's data type.
kv_cache_dtype: str
def __post_init__(self):
if self.num_prefill_tokens > 0:
assert self.num_prefills > 0
assert self.prefill_metadata is not None
if self.num_decode_tokens > 0:
assert self.decode_metadata is not None
class AttentionImpl(ABC): class AttentionImpl(ABC):
@abstractmethod @abstractmethod
@ -80,7 +116,7 @@ class AttentionImpl(ABC):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
kv_scale: float, kv_scale: float,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
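A minimal sketch of how the new per-batch metadata is assembled, using the fields defined above. The bare `AttentionMetadataPerStage()` instances stand in for backend-specific metadata such as `FlashAttentionMetadata`, and the slot values are made up:

```python
import torch

from vllm.attention.backends.abstract import (AttentionMetadata,
                                              AttentionMetadataPerStage)

# Two prefill requests contributing 10 tokens, plus 3 decode requests.
attn_metadata = AttentionMetadata(
    num_prefills=2,
    num_prefill_tokens=10,
    num_decode_tokens=3,
    prefill_metadata=AttentionMetadataPerStage(),
    decode_metadata=AttentionMetadataPerStage(),
    # One KV-cache slot index per batched token; e.g. with block size 16,
    # slot 35 means block 35 // 16 == 2, offset 35 % 16 == 3.
    slot_mapping=torch.arange(13),
    kv_cache_dtype="auto",
)
# __post_init__ requires per-stage metadata whenever the matching token
# count is non-zero.
```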

View File

@ -11,7 +11,8 @@ import torch
from flash_attn import flash_attn_varlen_func from flash_attn import flash_attn_varlen_func
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
@ -53,7 +54,8 @@ class FlashAttentionBackend(AttentionBackend):
@dataclass @dataclass
class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): class FlashAttentionMetadata(AttentionMetadataPerStage,
PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend. """Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is NOTE: Any python object stored here is not updated when it is
@ -68,10 +70,6 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
prompt_lens: Optional[List[int]] prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor. # prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor] prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen. # NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------| # |---------- N-1 iteration --------|
@ -107,18 +105,27 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
class FlashAttentionImpl(AttentionImpl): class FlashAttentionImpl(AttentionImpl):
""" """
If the input tensors contain prompt tokens, the layout is as follows: If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->| |<--------------- num_prefill_tokens ----------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows: Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->| |<----------------- num_decode_tokens ------------------>|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used. Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding. Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens The prompts might have different lengths, while the generation tokens
always have length 1. always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
""" """
def __init__( def __init__(
@ -155,7 +162,7 @@ class FlashAttentionImpl(AttentionImpl):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata, attn_metadata: AttentionMetadata[FlashAttentionMetadata],
kv_scale: float, kv_scale: float,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
@ -188,52 +195,70 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
kv_scale) kv_scale)
if attn_metadata.is_prompt: num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0: if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# normal attention # normal attention
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
output = flash_attn_varlen_func( out = flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
cu_seqlens_q=attn_metadata.seq_start_loc, cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=attn_metadata.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=attn_metadata.max_prompt_len, max_seqlen_q=prefill_meta.max_prompt_len,
max_seqlen_k=attn_metadata.max_prompt_len, max_seqlen_k=prefill_meta.max_prompt_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
window_size=self.sliding_window, window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
) )
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else: else:
# prefix-enabled attention # prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to # TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache, # deal with different data types between KV and FP8 KV cache,
# to be addressed separately. # to be addressed separately.
output = PagedAttention.forward_prefix( output[:num_prefill_tokens] = PagedAttention.forward_prefix(
query, query,
key, key,
value, value,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, prefill_meta.block_tables,
attn_metadata.subquery_start_loc, prefill_meta.subquery_start_loc,
attn_metadata.prompt_lens_tensor, prefill_meta.prompt_lens_tensor,
attn_metadata.context_lens, prefill_meta.context_lens,
attn_metadata.max_subquery_len, prefill_meta.max_subquery_len,
self.alibi_slopes, self.alibi_slopes,
) )
else: if decode_meta := attn_metadata.decode_metadata:
# Decoding run. # Decoding run.
output = PagedAttention.forward_decode( output[num_prefill_tokens:] = PagedAttention.forward_decode(
query, decode_query,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, decode_meta.block_tables,
attn_metadata.context_lens, decode_meta.context_lens,
attn_metadata.max_context_len, decode_meta.max_context_len,
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
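The forward pass above follows the layout described in the class docstring: prefill tokens occupy the front of the flattened token dimension and decode tokens the back, and each stage writes into its own slice of the output. A standalone sketch with made-up shapes:

```python
import torch

num_prefill_tokens, num_decode_tokens = 3, 2
num_heads, head_size = 4, 8
query = torch.randn(num_prefill_tokens + num_decode_tokens, num_heads, head_size)

# Prefill tokens come first in the flattened 1D layout, decode tokens last.
prefill_query = query[:num_prefill_tokens]
decode_query = query[num_prefill_tokens:]
assert prefill_query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens

# Each stage fills its own slice of the shared output buffer.
output = torch.empty_like(query)
output[:num_prefill_tokens] = prefill_query   # stand-in for the prefill result
output[num_prefill_tokens:] = decode_query    # stand-in for the decode result
```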

View File

@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type
import torch import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.logger import init_logger from vllm.logger import init_logger
@ -51,7 +52,8 @@ class ROCmFlashAttentionBackend(AttentionBackend):
@dataclass @dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend. """Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is NOTE: Any python object stored here is not updated when it is
@ -66,10 +68,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
prompt_lens: Optional[List[int]] prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor. # prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor] prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen. # NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------| # |---------- N-1 iteration --------|
@ -117,6 +115,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
The prompts might have different lengths, while the generation tokens The prompts might have different lengths, while the generation tokens
always have length 1. always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
|<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
""" """
def __init__( def __init__(
@ -181,7 +188,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata, attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
kv_scale: float = 1.0, kv_scale: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
@ -218,9 +225,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_scale, kv_scale,
) )
if attn_metadata.is_prompt: num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0: if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention # triton attention
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
@ -230,63 +253,69 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key = self.repeat_kv(key, self.num_queries_per_kv) key = self.repeat_kv(key, self.num_queries_per_kv)
value = self.repeat_kv(value, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv)
if self.use_naive_attn: if self.use_naive_attn:
output = self.attn_fuc( out = self.attn_fuc(
query, query,
key, key,
value, value,
attn_metadata.prompt_lens, prefill_meta.prompt_lens,
self.scale, self.scale,
) )
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else: else:
output, _ = self.attn_func( out, _ = self.attn_func(
query, query,
key, key,
value, value,
None, None,
attn_metadata.seq_start_loc, prefill_meta.seq_start_loc,
attn_metadata.seq_start_loc, prefill_meta.seq_start_loc,
attn_metadata.max_prompt_len, prefill_meta.max_prompt_len,
attn_metadata.max_prompt_len, prefill_meta.max_prompt_len,
True, True,
self.scale, self.scale,
) )
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else: else:
output = self.attn_func( out = self.attn_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
cu_seqlens_q=attn_metadata.seq_start_loc, cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=attn_metadata.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=attn_metadata.max_prompt_len, max_seqlen_q=prefill_meta.max_prompt_len,
max_seqlen_k=attn_metadata.max_prompt_len, max_seqlen_k=prefill_meta.max_prompt_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
) )
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else: else:
# prefix-enabled attention # prefix-enabled attention
output = PagedAttention.forward_prefix( output[:num_prefill_tokens] = PagedAttention.forward_prefix(
query, query,
key, key,
value, value,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, prefill_meta.block_tables,
attn_metadata.subquery_start_loc, prefill_meta.subquery_start_loc,
attn_metadata.prompt_lens_tensor, prefill_meta.prompt_lens_tensor,
attn_metadata.context_lens, prefill_meta.context_lens,
attn_metadata.max_subquery_len, prefill_meta.max_subquery_len,
self.alibi_slopes, self.alibi_slopes,
) )
else:
if decode_meta := attn_metadata.decode_metadata:
# Decoding run. # Decoding run.
output = PagedAttention.forward_decode( output[num_prefill_tokens:] = PagedAttention.forward_decode(
query, decode_query,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, decode_meta.block_tables,
attn_metadata.context_lens, decode_meta.context_lens,
attn_metadata.max_context_len, decode_meta.max_context_len,
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,

View File

@ -7,7 +7,8 @@ import torch
from torch.nn.functional import scaled_dot_product_attention from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
@ -49,17 +50,14 @@ class TorchSDPABackend(AttentionBackend):
@dataclass @dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
"""Metadata for TorchSDPABackend. """Metadata for TorchSDPABackend.
""" """
# Currently, input sequences can only contain all prompts # Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts. # or all decoding. True if all sequences are prompts.
is_prompt: bool is_prompt: bool
slot_mapping: torch.Tensor
prompt_lens: Optional[List[int]] prompt_lens: Optional[List[int]]
prompt_lens_tensor: Optional[torch.Tensor] prompt_lens_tensor: Optional[torch.Tensor]
num_prompt_tokens: int
num_generation_tokens: int
max_subquery_len: Optional[int] = None max_subquery_len: Optional[int] = None
max_prompt_len: Optional[int] = None max_prompt_len: Optional[int] = None
@ -113,7 +111,7 @@ class TorchSDPABackendImpl(AttentionImpl):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata, attn_metadata: AttentionMetadata[TorchSDPAMetadata],
kv_scale: float, kv_scale: float,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention. """Forward pass with torch SDPA and PagedAttention.
@ -142,36 +140,51 @@ class TorchSDPABackendImpl(AttentionImpl):
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
kv_scale) kv_scale)
if attn_metadata.is_prompt: num_prefill_tokens = attn_metadata.num_prefill_tokens
if (kv_cache is None or attn_metadata.block_tables.numel() == 0): num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
if (kv_cache is None or prefill_meta.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1) key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv, value = value.repeat_interleave(self.num_queries_per_kv,
dim=1) dim=1)
if attn_metadata.attn_bias is None: if prefill_meta.attn_bias is None:
if self.alibi_slopes is not None: if self.alibi_slopes is not None:
att_masks = _make_alibi_bias( att_masks = _make_alibi_bias(
self.alibi_slopes, query.dtype, self.alibi_slopes, query.dtype,
attn_metadata.prompt_lens) # type: ignore prefill_meta.prompt_lens) # type: ignore
elif self.sliding_window is not None: elif self.sliding_window is not None:
att_masks = _make_sliding_window_bias( att_masks = _make_sliding_window_bias(
attn_metadata.prompt_lens, self.sliding_window, prefill_meta.prompt_lens, self.sliding_window,
query.dtype) # type: ignore query.dtype) # type: ignore
else: else:
att_masks = [None] * len(attn_metadata.prompt_lens) att_masks = [None] * len(prefill_meta.prompt_lens)
attn_metadata.attn_bias = att_masks prefill_meta.attn_bias = att_masks
query = query.movedim(0, query.dim() - 2) query = query.movedim(0, query.dim() - 2)
key = key.movedim(0, key.dim() - 2) key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2) value = value.movedim(0, value.dim() - 2)
start = 0 start = 0
output = torch.empty( out = torch.empty((num_tokens, self.num_heads, self.head_size),
(num_tokens, self.num_heads, self.head_size), dtype=query.dtype)
dtype=query.dtype) for prompt_len, mask in zip(prefill_meta.prompt_lens,
for prompt_len, mask in zip(attn_metadata.prompt_lens, prefill_meta.attn_bias):
attn_metadata.attn_bias):
end = start + prompt_len end = start + prompt_len
sub_out = scaled_dot_product_attention( sub_out = scaled_dot_product_attention(
query[:, start:end, :], query[:, start:end, :],
@ -181,28 +194,32 @@ class TorchSDPABackendImpl(AttentionImpl):
dropout_p=0.0, dropout_p=0.0,
is_causal=not self.need_mask, is_causal=not self.need_mask,
scale=self.scale).movedim(query.dim() - 2, 0) scale=self.scale).movedim(query.dim() - 2, 0)
output[start:end, :, :] = sub_out out[start:end, :, :] = sub_out
start = end start = end
assert out.shape == output[:num_prefill_tokens].shape
output[:num_prefill_tokens] = out
else: else:
# prefix-enabled attention # prefix-enabled attention
raise RuntimeError( raise RuntimeError(
"Torch SDPA backend doesn't support prefix decoding.") "Torch SDPA backend doesn't support prefix decoding.")
else: if decode_meta := attn_metadata.decode_metadata:
# Decoding run. # Decoding run.
output = PagedAttention.forward_decode( out = PagedAttention.forward_decode(
query, decode_query,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, decode_meta.block_tables,
attn_metadata.context_lens, decode_meta.context_lens,
attn_metadata.max_context_len, decode_meta.max_context_len,
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
self.alibi_slopes, self.alibi_slopes,
kv_scale, kv_scale,
) )
assert out.shape == output[num_prefill_tokens:].shape
output[num_prefill_tokens:] = out
# Reshape the output tensor. # Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size) return output.view(-1, self.num_heads * self.head_size)

View File

@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
LowerTriangularMaskWithTensorBias) LowerTriangularMaskWithTensorBias)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.logger import init_logger from vllm.logger import init_logger
@ -54,7 +55,7 @@ class XFormersBackend(AttentionBackend):
@dataclass @dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
"""Metadata for XFormersbackend. """Metadata for XFormersbackend.
NOTE: Any python object stored here is not updated when it is NOTE: Any python object stored here is not updated when it is
@ -65,19 +66,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Currently, input sequences can only contain all prompts # Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts. # or all decoding. True if all sequences are prompts.
is_prompt: bool is_prompt: bool
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# (batch_size,). The prompt length per sequence. None if it is a decoding. # (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]] prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor. # prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor] prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen. # NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------| # |---------- N-1 iteration --------|
@ -123,18 +115,27 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
class XFormersImpl(AttentionImpl): class XFormersImpl(AttentionImpl):
""" """
If the input tensors contain prompt tokens, the layout is as follows: If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens --------------->| |<--------------- num_prefill_tokens ----------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->| |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows: Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->| |<----------------- num_decode_tokens ------------------>|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used. Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding. Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens The prompts might have different lengths, while the generation tokens
always have length 1. always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
""" """
def __init__( def __init__(
@ -170,7 +171,7 @@ class XFormersImpl(AttentionImpl):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: XFormersMetadata, attn_metadata: AttentionMetadata[XFormersMetadata],
kv_scale: float, kv_scale: float,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
@ -202,59 +203,61 @@ class XFormersImpl(AttentionImpl):
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
kv_scale) kv_scale)
if attn_metadata.is_prompt: num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0: if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# normal attention. # normal attention.
# block tables are empty if the prompt does not have a cached # block tables are empty if the prompt does not have a cached
# prefix. # prefix.
if self.num_kv_heads != self.num_heads: out = self._run_memory_efficient_xformers_forward(
# As of Nov 2023, xformers only supports MHA. For MQA/GQA, query, key, value, prefill_meta)
# project the key and value tensors to the desired number of assert out.shape == output[:num_prefill_tokens].shape
# heads. output[:num_prefill_tokens] = out
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
output = self._run_memory_efficient_xformers_forward(
query, key, value, attn_metadata)
else: else:
# prefix-enabled attention # prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to # TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache, # deal with different data types between KV and FP8 KV cache,
# to be addressed separately. # to be addressed separately.
output = PagedAttention.forward_prefix( out = PagedAttention.forward_prefix(
query, query,
key, key,
value, value,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, prefill_meta.block_tables,
attn_metadata.subquery_start_loc, prefill_meta.subquery_start_loc,
attn_metadata.prompt_lens_tensor, prefill_meta.prompt_lens_tensor,
attn_metadata.context_lens, prefill_meta.context_lens,
attn_metadata.max_subquery_len, prefill_meta.max_subquery_len,
self.alibi_slopes, self.alibi_slopes,
) )
else: assert output[:num_prefill_tokens].shape == out.shape
# Decoding run. output[:num_prefill_tokens] = out
output = PagedAttention.forward_decode(
query, if decode_meta := attn_metadata.decode_metadata:
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, decode_meta.block_tables,
attn_metadata.context_lens, decode_meta.context_lens,
attn_metadata.max_context_len, decode_meta.max_context_len,
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
@ -275,13 +278,30 @@ class XFormersImpl(AttentionImpl):
"""Attention for 1D query of multiple prompts. Multiple prompt """Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened into the `query` input. tokens are flattened into the `query` input.
See https://facebookresearch.github.io/xformers/components/ops.html
for API spec.
Args: Args:
output: shape = [num_prompt_tokens, num_heads, head_size] output: shape = [num_prefill_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size] query: shape = [num_prefill_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size] key: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size] value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
""" """
original_query = query
if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K].
# Note that the output also has the same shape (which differs from
# the shape documented in the xformers API spec).
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv, query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# Set attention bias if not provided. This typically happens at # Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration. # the very attention layer of every iteration.
# FIXME(woosuk): This is a hack. # FIXME(woosuk): This is a hack.
@ -302,6 +322,7 @@ class XFormersImpl(AttentionImpl):
# TODO(woosuk): Too many view operations. Let's try to reduce # TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability. # them in the future for code readability.
if self.alibi_slopes is None: if self.alibi_slopes is None:
# Add the batch dimension.
query = query.unsqueeze(0) query = query.unsqueeze(0)
key = key.unsqueeze(0) key = key.unsqueeze(0)
value = value.unsqueeze(0) value = value.unsqueeze(0)
@ -312,14 +333,13 @@ class XFormersImpl(AttentionImpl):
attn_bias=attn_metadata.attn_bias[0], attn_bias=attn_metadata.attn_bias[0],
p=0.0, p=0.0,
scale=self.scale) scale=self.scale)
return out.view_as(original_query)
return out.view_as(query)
# Attention with alibi slopes. # Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence # FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by # lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts. # one. This is inefficient, especially when we have many short prompts.
output = torch.empty_like(query) output = torch.empty_like(original_query)
start = 0 start = 0
for i, prompt_len in enumerate(attn_metadata.prompt_lens): for i, prompt_len in enumerate(attn_metadata.prompt_lens):
end = start + prompt_len end = start + prompt_len
@ -331,7 +351,7 @@ class XFormersImpl(AttentionImpl):
p=0.0, p=0.0,
scale=self.scale) scale=self.scale)
# TODO(woosuk): Unnecessary copy. Optimize. # TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.squeeze(0)) output[start:end].copy_(out.view_as(original_query[start:end]))
start += prompt_len start += prompt_len
return output return output
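
For MQA/GQA the query is regrouped to [num_tokens, num_kv_heads, num_queries_per_kv, head_size] and key/value are expanded along the query-group axis, so xformers sees matching head counts without materializing copies. A small standalone illustration with assumed toy sizes:

    import torch

    num_tokens, num_heads, num_kv_heads, head_size = 5, 8, 2, 16
    num_queries_per_kv = num_heads // num_kv_heads

    query = torch.randn(num_tokens, num_heads, head_size)
    key = torch.randn(num_tokens, num_kv_heads, head_size)

    # Group query heads by the KV head they attend with.
    grouped_query = query.view(num_tokens, num_kv_heads, num_queries_per_kv,
                               head_size)
    # Broadcast each KV head across its query group; expand() adds no copy.
    expanded_key = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                             num_queries_per_kv, head_size)
    assert grouped_query.shape == expanded_key.shape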

View File

@ -4,7 +4,8 @@ from typing import List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import (AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
@ -41,7 +42,7 @@ class Attention(nn.Module):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
kv_scale: float = 1.0, kv_scale: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata, return self.impl.forward(query, key, value, kv_cache, attn_metadata,

View File

@ -13,11 +13,6 @@ _PARTITION_SIZE = 512
@dataclass @dataclass
class PagedAttentionMetadata: class PagedAttentionMetadata:
"""Metadata for PagedAttention.""" """Metadata for PagedAttention."""
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# (batch_size,). The length of context (tokens stored in KV cache) per # (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new # sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token. # tokens. When it is for decoding, it includes a new token.
@ -31,7 +26,6 @@ class PagedAttentionMetadata:
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured. # captured.
block_tables: Optional[torch.Tensor] block_tables: Optional[torch.Tensor]
kv_cache_dtype: str
class PagedAttention: class PagedAttention:
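
slot_mapping and kv_cache_dtype leave the per-backend metadata because they now belong to the batch-level wrapper that holds separate prefill and decode metadata (the AttentionMetadata[AttentionMetadataPerStage] seen above). A rough sketch of that two-level shape, with simplified field names:

    from dataclasses import dataclass
    from typing import Generic, Optional, TypeVar

    import torch

    T = TypeVar("T")

    @dataclass
    class StageMetadata:
        """Per-stage fields, e.g. block tables for prefill or decode."""
        block_tables: Optional[torch.Tensor]

    @dataclass
    class BatchMetadata(Generic[T]):
        """Batch-level fields shared by both stages."""
        num_prefills: int
        num_prefill_tokens: int
        num_decode_tokens: int
        slot_mapping: torch.Tensor
        kv_cache_dtype: str
        prefill_metadata: Optional[T]
        decode_metadata: Optional[T]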

View File

@ -565,9 +565,16 @@ class SchedulerConfig:
if max_num_batched_tokens is not None: if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
else: else:
# If max_model_len is too short, use 2048 as the default value for if enable_chunked_prefill:
# higher throughput. # For chunked prefill, choose the well-tuned batch size.
self.max_num_batched_tokens = max(max_model_len, 2048) self.max_num_batched_tokens = 768
else:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048)
if enable_chunked_prefill:
logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
self.max_num_seqs = max_num_seqs self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.use_v2_block_manager = use_v2_block_manager self.use_v2_block_manager = use_v2_block_manager
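
When the token budget is not set explicitly, chunked prefill defaults to a 768-token budget per step, while the old max(max_model_len, 2048) default still applies otherwise. The selection logic, condensed into a standalone helper (name is illustrative):

    def default_max_num_batched_tokens(max_num_batched_tokens, max_model_len,
                                       enable_chunked_prefill):
        # Condensed restatement of the SchedulerConfig branching above.
        if max_num_batched_tokens is not None:
            return max_num_batched_tokens
        if enable_chunked_prefill:
            # Well-tuned per-step token budget for chunked prefill.
            return 768
        # If max_model_len is too short, use 2048 for higher throughput.
        return max(max_model_len, 2048)

    assert default_max_num_batched_tokens(None, 4096, True) == 768
    assert default_max_num_batched_tokens(None, 512, False) == 2048
    assert default_max_num_batched_tokens(None, 4096, False) == 4096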

View File

@ -140,7 +140,11 @@ class SchedulerOutputs:
@property @property
def lora_requests(self) -> Set[LoRARequest]: def lora_requests(self) -> Set[LoRARequest]:
return {g.seq_group.lora_request for g in self.scheduled_seq_groups} return {
g.seq_group.lora_request
for g in self.scheduled_seq_groups
if g.seq_group.lora_request is not None
}
@dataclass @dataclass
@ -826,13 +830,12 @@ class Scheduler:
# Update swapped requests. # Update swapped requests.
self.swapped = remaining_swapped self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out) self.swapped.extend(running_scheduled.swapped_out)
return SchedulerOutputs( return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups + scheduled_seq_groups=(prefills.seq_groups +
running_scheduled.decode_seq_groups +
running_scheduled.prefill_seq_groups + running_scheduled.prefill_seq_groups +
swapped_in.decode_seq_groups + swapped_in.prefill_seq_groups +
swapped_in.prefill_seq_groups), running_scheduled.decode_seq_groups +
swapped_in.decode_seq_groups),
num_prefill_groups=(len(prefills.seq_groups) + num_prefill_groups=(len(prefills.seq_groups) +
len(swapped_in.prefill_seq_groups) + len(swapped_in.prefill_seq_groups) +
len(running_scheduled.prefill_seq_groups)), len(running_scheduled.prefill_seq_groups)),
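
The reordering puts every prefill group (newly scheduled, running chunked, and swapped-in) ahead of every decode group, which is the invariant later code relies on when slicing the flattened batch. A compact restatement with placeholder lists:

    # Placeholder lists standing in for the scheduler's bookkeeping.
    prefills = ["new_prefill"]
    running_prefills = ["running_chunked_prefill"]
    swapped_in_prefills = ["swapped_in_prefill"]
    running_decodes = ["running_decode"]
    swapped_in_decodes = ["swapped_in_decode"]

    # All prefill groups come before all decode groups.
    scheduled_seq_groups = (prefills + running_prefills + swapped_in_prefills +
                            running_decodes + swapped_in_decodes)
    num_prefill_groups = (len(prefills) + len(running_prefills) +
                          len(swapped_in_prefills))
    assert all("prefill" in g for g in scheduled_seq_groups[:num_prefill_groups])
    assert all("decode" in g for g in scheduled_seq_groups[num_prefill_groups:])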
@ -907,7 +910,7 @@ class Scheduler:
# It assumes the scheduled_seq_groups is ordered by # It assumes the scheduled_seq_groups is ordered by
# prefill < decoding. # prefill < decoding.
is_prompt = i < scheduler_outputs.num_prefill_groups is_prompt = seq_group.is_prefill()
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id, request_id=seq_group.request_id,
is_prompt=is_prompt, is_prompt=is_prompt,

View File

@ -173,10 +173,18 @@ def broadcast_tensor_dict(
torch.distributed.broadcast_object_list([metadata_list], torch.distributed.broadcast_object_list([metadata_list],
src=src, src=src,
group=group) group=group)
async_handles = []
for key, value in metadata_list: for key, value in metadata_list:
if isinstance(value, TensorMetadata): if isinstance(value, TensorMetadata):
tensor = tensor_dict[key] tensor = tensor_dict[key]
torch.distributed.broadcast(tensor, src=src, group=group) async_handles.append(
torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True))
for async_handle in async_handles:
async_handle.wait()
else: else:
recv_metadata_list = [None] recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list, torch.distributed.broadcast_object_list(recv_metadata_list,
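
Issuing every tensor broadcast with async_op=True and only then waiting on the handles lets the collectives overlap instead of running one blocking broadcast per tensor. The pattern in isolation (assumes an already-initialized torch.distributed process group):

    import torch.distributed as dist

    def broadcast_tensors_overlapped(tensors, src=0, group=None):
        # Launch all broadcasts before waiting on any of them.
        handles = [
            dist.broadcast(tensor, src=src, group=group, async_op=True)
            for tensor in tensors
        ]
        for handle in handles:
            handle.wait()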

View File

@ -386,9 +386,8 @@ class EngineArgs:
'prompt latency) before scheduling next prompt.') 'prompt latency) before scheduling next prompt.')
parser.add_argument( parser.add_argument(
'--enable-chunked-prefill', '--enable-chunked-prefill',
type=bool, action='store_true',
default=False, help='If set, the prefill requests can be chunked based on the '
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens') 'max_num_batched_tokens')
parser.add_argument( parser.add_argument(
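
Switching the flag to action='store_true' fixes the usual argparse pitfall where type=bool turns any non-empty string, including "False", into True. A standalone check of the corrected behavior:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--enable-chunked-prefill', action='store_true',
                        help='If set, the prefill requests can be chunked '
                             'based on the max_num_batched_tokens')

    assert parser.parse_args([]).enable_chunked_prefill is False
    assert parser.parse_args(
        ['--enable-chunked-prefill']).enable_chunked_prefill is True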

View File

@ -633,7 +633,10 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens( seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size) scheduled_seq_group.token_chunk_size)
self._process_sequence_group_outputs(seq_group, outputs) # If uncomputed tokens > 0, it means prefill is chunked.
# We don't need to process outputs in that case.
if seq_group.get_num_uncomputed_tokens() == 0:
self._process_sequence_group_outputs(seq_group, outputs)
# Free the finished sequence groups. # Free the finished sequence groups.
self.scheduler.free_finished_seq_groups() self.scheduler.free_finished_seq_groups()
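
Each chunked step only advances the sequence group's computed-token count; sampling and output processing happen once no uncomputed prompt tokens remain, i.e. on the final chunk. A toy walkthrough of that bookkeeping with assumed numbers:

    # Toy numbers: a 1000-token prompt prefilled with a 768-token budget.
    prompt_len, chunk_budget = 1000, 768
    num_computed_tokens = 0
    steps = []

    while num_computed_tokens < prompt_len:
        token_chunk_size = min(chunk_budget, prompt_len - num_computed_tokens)
        num_computed_tokens += token_chunk_size
        # Outputs are processed only when the whole prompt has been prefilled.
        process_outputs = (prompt_len - num_computed_tokens) == 0
        steps.append((token_chunk_size, process_outputs))

    assert steps == [(768, False), (232, True)]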

View File

@ -267,12 +267,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1 added_tokens_mask = x > self.base_layer.org_vocab_size - 1
indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) embedding_len = self.indices_len[3]
indices = self.embeddings_indices[1][:embedding_len].view_as(x)
full_lora_a_embeddings = F.embedding( full_lora_a_embeddings = F.embedding(
x + indices, x + indices,
self.lora_a_stacked_2d, self.lora_a_stacked_2d,
) )
indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) indices = self.embeddings_indices[0][:embedding_len].view_as(x)
full_output = self.base_layer.forward( full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask)) x.add_(indices * added_tokens_mask))

View File

@ -500,7 +500,8 @@ class SequenceGroup:
def get_num_uncomputed_tokens(self) -> int: def get_num_uncomputed_tokens(self) -> int:
num_uncomputed_tokens = 0 num_uncomputed_tokens = 0
for seq in self.get_seqs(): for seq in self.get_seqs():
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() if not seq.is_finished():
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
return num_uncomputed_tokens return num_uncomputed_tokens
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:

View File

@ -1,12 +1,14 @@
import contextlib import contextlib
import time import time
from typing import Dict, List, Optional, Set, Tuple from enum import IntEnum
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend)
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig, from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig) SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
@ -37,6 +39,66 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
] ]
class PreparePromptMetadata(NamedTuple):
input_tokens: List[int]
input_positions: List[int]
attn_metadata: Optional[AttentionMetadataPerStage]
prompt_lens: List[int]
subquery_lens: List[int]
lora_index_mapping: List[int]
lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest]
multi_modal_input: Optional[torch.Tensor]
slot_mapping: List[int]
@classmethod
def empty(cls):
return PreparePromptMetadata(
input_tokens=[],
input_positions=[],
attn_metadata=None,
prompt_lens=[],
subquery_lens=[],
lora_index_mapping=[],
lora_prompt_mapping=[],
lora_requests=set(),
multi_modal_input=None,
slot_mapping=[],
)
class PrepareDecodeMetadata(NamedTuple):
input_tokens: List[int]
input_positions: List[int]
attn_metadata: Optional[AttentionMetadata]
lora_index_mapping: List[int]
lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest]
slot_mapping: List[int]
@classmethod
def empty(cls):
return PrepareDecodeMetadata(
input_tokens=[],
input_positions=[],
attn_metadata=None,
lora_index_mapping=[],
lora_prompt_mapping=[],
lora_requests=set(),
slot_mapping=[],
)
# How batches are constructed.
class BatchType(IntEnum):
# Every batch is prefill.
PREFILL = 0
# Every batch is decode.
DECODE = 1
# Batch is a mixture of prefill and decode.
MIXED = 2
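
With prefills and decodes prepared separately, the driver tags each batch so workers know whether one or two attention-metadata broadcasts will follow. A condensed classification sketch reusing the enum values:

    from enum import IntEnum

    class BatchType(IntEnum):
        PREFILL = 0
        DECODE = 1
        MIXED = 2

    def classify(prefill_attn_metadata, decode_attn_metadata):
        # Mirrors the driver-side branching in prepare_input_tensors below.
        if prefill_attn_metadata is not None and decode_attn_metadata is not None:
            return BatchType.MIXED
        if prefill_attn_metadata is not None:
            return BatchType.PREFILL
        return BatchType.DECODE

    assert classify(object(), object()) is BatchType.MIXED
    assert classify(object(), None) is BatchType.PREFILL
    assert classify(None, object()) is BatchType.DECODE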
class ModelRunner: class ModelRunner:
def __init__( def __init__(
@ -152,10 +214,7 @@ class ModelRunner:
def _prepare_prompt( def _prepare_prompt(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], ) -> PreparePromptMetadata:
List[int], List[int], List[int], Set[LoRARequest],
torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
slot_mapping: List[int] = [] slot_mapping: List[int] = []
@ -169,6 +228,9 @@ class ModelRunner:
prefix_block_tables: List[List[int]] = [] prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = [] multi_modal_input_list: List[torch.Tensor] = []
if len(seq_group_metadata_list) == 0:
return PreparePromptMetadata.empty()
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
@ -178,7 +240,8 @@ class ModelRunner:
computed_block_nums = seq_group_metadata.computed_block_nums computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled and self.scheduler_config.chunked_prefill_enabled
and computed_block_nums is not None): and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError( raise RuntimeError(
"chunked prefill cannot be used with prefix caching " "chunked prefill cannot be used with prefix caching "
"now.") "now.")
@ -190,13 +253,8 @@ class ModelRunner:
# it contains output tokens. # it contains output tokens.
prefill_end = min(seq_data.get_len(), prefill_end = min(seq_data.get_len(),
computed_len + token_chunk_size) computed_len + token_chunk_size)
# TODO(sang): Rename it after chunked prefill is introduced.
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
prompt_len = len(prompt_tokens) prompt_len = prefill_end
# Right now, the prefill_end is always same as the length of
# sequence. However, once chunked prefill is introduced, this
# assumption can be changed.
assert prefill_end == seq_data.get_len()
prompt_lens.append(prompt_len) prompt_lens.append(prompt_len)
# NOTE: This only works for oooooooxxx style attention. # NOTE: This only works for oooooooxxx style attention.
@ -206,6 +264,14 @@ class ModelRunner:
computed_len = len(computed_block_nums) * self.block_size computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:] prompt_tokens = prompt_tokens[computed_len:]
prefix_block_tables.append(computed_block_nums) prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
if seq_group_metadata.block_tables is not None:
                    # Prefill has been chunked before.
block_table = seq_group_metadata.block_tables[seq_id]
prefix_block_tables.append(block_table)
else:
# The first prefill.
prefix_block_tables.append([])
else: else:
prefix_block_tables.append([]) prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this # Right now, prefill start is always 0. However, this
@ -267,20 +333,8 @@ class ModelRunner:
max_subquery_len = max(subquery_lens) max_subquery_len = max(subquery_lens)
max_prompt_len = max(prompt_lens) max_prompt_len = max(prompt_lens)
num_prompt_tokens = len(input_tokens)
assert max_subquery_len > 0 assert max_subquery_len > 0
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
lora_index_mapping = lora_index_mapping
context_lens_tensor = torch.tensor(context_lens, context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
@ -332,11 +386,8 @@ class ModelRunner:
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=True, is_prompt=True,
slot_mapping=slot_mapping,
prompt_lens=prompt_lens, prompt_lens=prompt_lens,
prompt_lens_tensor=prompt_lens_tensor, prompt_lens_tensor=prompt_lens_tensor,
num_prompt_tokens=num_prompt_tokens,
num_generation_tokens=0,
max_subquery_len=max_subquery_len, max_subquery_len=max_subquery_len,
max_context_len=None, max_context_len=None,
max_prompt_len=max_prompt_len, max_prompt_len=max_prompt_len,
@ -345,18 +396,25 @@ class ModelRunner:
context_lens=context_lens_tensor, context_lens=context_lens_tensor,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=False, use_cuda_graph=False,
kv_cache_dtype=self.kv_cache_dtype,
) )
return (input_tokens, input_positions, attn_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping, return PreparePromptMetadata(
lora_requests, multi_modal_input) input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
prompt_lens=prompt_lens,
subquery_lens=subquery_lens,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
multi_modal_input=multi_modal_input,
slot_mapping=slot_mapping,
)
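
For chunked prefill, each scheduling step covers the token window [computed_len, computed_len + token_chunk_size), and every chunk after the first reuses the block table already allocated for the sequence. A small sketch of how those windows advance, using a hypothetical helper and toy numbers:

    def prefill_windows(seq_len, token_chunk_size):
        """Yield (computed_len, prefill_end) windows for one chunked prefill."""
        computed_len = 0
        while computed_len < seq_len:
            prefill_end = min(seq_len, computed_len + token_chunk_size)
            yield computed_len, prefill_end
            computed_len = prefill_end

    assert list(prefill_windows(seq_len=1000, token_chunk_size=400)) == [
        (0, 400), (400, 800), (800, 1000),
    ]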
def _prepare_decode( def _prepare_decode(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], ) -> PrepareDecodeMetadata:
List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
slot_mapping: List[int] = [] slot_mapping: List[int] = []
@ -366,6 +424,9 @@ class ModelRunner:
lora_prompt_mapping: List[int] = [] lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set() lora_requests: Set[LoRARequest] = set()
if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty()
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1 assert seq_group_metadata.token_chunk_size == 1
@ -424,15 +485,6 @@ class ModelRunner:
lora_index_mapping.append(0) lora_index_mapping.append(0)
batch_size = graph_batch_size batch_size = graph_batch_size
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens, context_lens = torch.tensor(context_lens,
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
@ -440,9 +492,9 @@ class ModelRunner:
if use_captured_graph: if use_captured_graph:
# When using cuda-graph all these tensors should be # When using cuda-graph all these tensors should be
# padded. # padded.
assert context_lens.shape[0] == input_tokens.shape[0] assert context_lens.shape[0] == len(input_tokens)
assert context_lens.shape[0] == input_positions.shape[0] assert context_lens.shape[0] == len(input_positions)
assert context_lens.shape[0] == slot_mapping.shape[0] assert context_lens.shape[0] == len(slot_mapping)
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
@ -464,11 +516,8 @@ class ModelRunner:
attn_metadata = self.attn_backend.make_metadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping,
prompt_lens=None, prompt_lens=None,
prompt_lens_tensor=None, prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=len(input_tokens),
max_subquery_len=None, max_subquery_len=None,
max_context_len=max_context_len, max_context_len=max_context_len,
max_prompt_len=None, max_prompt_len=None,
@ -477,10 +526,16 @@ class ModelRunner:
context_lens=context_lens, context_lens=context_lens,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
kv_cache_dtype=self.kv_cache_dtype,
) )
return (input_tokens, input_positions, attn_metadata, return PrepareDecodeMetadata(
lora_index_mapping, lora_prompt_mapping, lora_requests) input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
slot_mapping=slot_mapping,
)
def _prepare_sample( def _prepare_sample(
self, self,
@ -586,26 +641,66 @@ class ModelRunner:
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[int], LoRAMapping, torch.Tensor]: Set[int], LoRAMapping, torch.Tensor]:
if self.is_driver_worker: if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or prefill_reqs = []
# all decodes. decode_reqs = []
is_prompt = seq_group_metadata_list[0].is_prompt for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
prefill_reqs.append(seq_group_meta)
else:
decode_reqs.append(seq_group_meta)
# Prepare input tensors. # Prepare input tensors.
if is_prompt: (
(input_tokens, input_positions, attn_metadata, prompt_lens, input_tokens,
subquery_lens, lora_index_mapping, lora_prompt_mapping, input_positions,
lora_requests, multi_modal_input prefill_attn_metadata,
) = self._prepare_prompt(seq_group_metadata_list) prompt_lens,
else: subquery_lens,
(input_tokens, input_positions, attn_metadata, lora_index_mapping,
lora_index_mapping, lora_prompt_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_decode(seq_group_metadata_list) lora_requests,
prompt_lens = [] multi_modal_input,
subquery_lens = None slot_mapping,
multi_modal_input = None ) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
decode_input_positions,
decode_attn_metadata,
decode_lora_index_mapping,
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
sampling_metadata = self._prepare_sample(seq_group_metadata_list, sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens, prompt_lens,
subquery_lens) subquery_lens)
if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0
num_prefills = len(prompt_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)
# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens.extend(decode_input_tokens)
input_positions.extend(decode_input_positions)
slot_mapping.extend(decode_slot_mapping)
lora_index_mapping.extend(decode_lora_index_mapping)
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
lora_requests.update(decode_lora_requests)
input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
if self.lora_config: if self.lora_config:
lora_mapping = LoRAMapping( lora_mapping = LoRAMapping(
lora_index_mapping, lora_index_mapping,
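
Because _prepare_prompt and _prepare_decode now return plain Python lists, the driver can concatenate prefill and decode inputs and build each tensor exactly once, keeping the prefill-first layout the attention backend slices on. A compact illustration with toy token ids:

    import torch

    # Toy stand-ins for separately prepared prefill and decode inputs.
    input_tokens = [11, 12, 13, 14]   # flattened prefill tokens
    decode_input_tokens = [21, 22]    # one token per decoding sequence

    num_prefill_tokens = len(input_tokens)
    num_decode_tokens = len(decode_input_tokens)

    # Coalesce: prefill tokens first, then decode tokens, then tensorize once.
    input_tokens.extend(decode_input_tokens)
    input_tokens = torch.tensor(input_tokens, dtype=torch.long)

    assert input_tokens[:num_prefill_tokens].tolist() == [11, 12, 13, 14]
    assert input_tokens[num_prefill_tokens:].tolist() == [21, 22]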
@ -615,6 +710,16 @@ class ModelRunner:
lora_mapping = None lora_mapping = None
# Broadcast the metadata. # Broadcast the metadata.
            # If the batch contains both prefill and decode, it sends 2 broadcasts.
            # If it contains only one type, a single broadcast is sent.
if (prefill_attn_metadata is not None
and decode_attn_metadata is not None):
batch_type = BatchType.MIXED
elif prefill_attn_metadata is not None:
batch_type = BatchType.PREFILL
else:
batch_type = BatchType.DECODE
metadata_dict = { metadata_dict = {
"input_tokens": input_tokens, "input_tokens": input_tokens,
"input_positions": input_positions, "input_positions": input_positions,
@ -623,19 +728,49 @@ class ModelRunner:
"lora_requests": lora_requests, "lora_requests": lora_requests,
"lora_mapping": lora_mapping, "lora_mapping": lora_mapping,
"multi_modal_input": multi_modal_input, "multi_modal_input": multi_modal_input,
"num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
"batch_type": batch_type,
} }
metadata_dict.update(attn_metadata.asdict_zerocopy()) if prefill_attn_metadata is not None:
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
else:
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0) broadcast_tensor_dict(metadata_dict, src=0)
# Broadcast decode attn metadata for mixed batch type.
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
            # We can potentially reduce the overhead by coalescing tensors.
if batch_type == BatchType.MIXED:
assert decode_attn_metadata is not None
metadata_dict = decode_attn_metadata.asdict_zerocopy()
broadcast_tensor_dict(metadata_dict, src=0)
else: else:
metadata_dict = broadcast_tensor_dict(src=0) metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens") input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions") input_positions = metadata_dict.pop("input_positions")
slot_mapping = metadata_dict.pop("slot_mapping")
num_prefills = metadata_dict.pop("num_prefills")
selected_token_indices = metadata_dict.pop( selected_token_indices = metadata_dict.pop(
"selected_token_indices") "selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping") lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests") lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input") multi_modal_input = metadata_dict.pop("multi_modal_input")
attn_metadata = self.attn_backend.make_metadata(**metadata_dict) num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
batch_type = metadata_dict.pop("batch_type")
# Create an attention metadata.
prefill_attn_metadata = None
decode_attn_metadata = None
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
prefill_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
else:
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
sampling_metadata = SamplingMetadata( sampling_metadata = SamplingMetadata(
seq_groups=None, seq_groups=None,
seq_data=None, seq_data=None,
@ -646,6 +781,23 @@ class ModelRunner:
perform_sampling=False, perform_sampling=False,
) )
        # If it is a mixed batch, the decode attn_metadata is broadcast
# separately.
if batch_type == BatchType.MIXED:
metadata_dict = broadcast_tensor_dict(src=0)
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
attn_metadata = AttentionMetadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata, return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping, sampling_metadata, lora_requests, lora_mapping,
multi_modal_input) multi_modal_input)
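
On the worker side, the first broadcast carries the batch type plus one stage's attention metadata; a MIXED batch then carries the decode metadata in a second broadcast before the combined AttentionMetadata is assembled. A schematic of that receive path, where recv_dict and make_metadata are stand-ins for broadcast_tensor_dict(src=0) and the backend's metadata factory:

    class BatchType:
        # Stand-in mirroring the IntEnum above.
        PREFILL, DECODE, MIXED = 0, 1, 2

    def receive_stage_metadata(recv_dict, make_metadata):
        metadata_dict = recv_dict()
        batch_type = metadata_dict.pop("batch_type")
        prefill_meta = decode_meta = None
        if batch_type in (BatchType.PREFILL, BatchType.MIXED):
            prefill_meta = make_metadata(**metadata_dict)
        else:
            decode_meta = make_metadata(**metadata_dict)
        if batch_type == BatchType.MIXED:
            # Mixed batches send the decode metadata in a second broadcast.
            decode_meta = make_metadata(**recv_dict())
        return prefill_meta, decode_meta

    msgs = iter([{"batch_type": BatchType.MIXED, "stage": "prefill"},
                 {"stage": "decode"}])
    prefill_meta, decode_meta = receive_stage_metadata(
        recv_dict=lambda: next(msgs), make_metadata=lambda **kw: kw)
    assert prefill_meta == {"stage": "prefill"}
    assert decode_meta == {"stage": "decode"}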
@ -663,8 +815,10 @@ class ModelRunner:
if self.lora_config: if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping) self.set_active_loras(lora_requests, lora_mapping)
# Execute the model. # Currently cuda graph is only supported by the decode phase.
if attn_metadata.use_cuda_graph: prefill_meta = attn_metadata.prefill_metadata
decode_meta = attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
graph_batch_size = input_tokens.shape[0] graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size] model_executable = self.graph_runners[graph_batch_size]
else: else:
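
CUDA graphs are still captured only for the decode phase, so the graph runner is used only when the batch has no prefill metadata and the decode metadata was built for a captured graph size. The dispatch decision in isolation, with stand-in objects and illustrative names:

    from types import SimpleNamespace

    def pick_executable(prefill_meta, decode_meta, graph_runners, eager_model):
        # Any prefill in the batch forces eager execution of the model.
        if prefill_meta is None and decode_meta.use_cuda_graph:
            return graph_runners[decode_meta.batch_size]
        return eager_model

    decode_meta = SimpleNamespace(use_cuda_graph=True, batch_size=8)
    runners = {8: "captured_graph"}
    assert pick_executable(None, decode_meta, runners, "eager") == "captured_graph"
    assert pick_executable("prefill", decode_meta, runners, "eager") == "eager"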
@ -842,13 +996,10 @@ class ModelRunner:
# memory usage of CUDA graph. # memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list): for batch_size in reversed(batch_size_capture_list):
# Create dummy attn_metadata. # Create dummy attn_metadata.
attn_metadata = self.attn_backend.make_metadata( decode_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping[:batch_size],
prompt_lens=None, prompt_lens=None,
prompt_lens_tensor=None, prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=batch_size,
max_subquery_len=None, max_subquery_len=None,
max_context_len=self.max_context_len_to_capture, max_context_len=self.max_context_len_to_capture,
max_prompt_len=None, max_prompt_len=None,
@ -857,6 +1008,14 @@ class ModelRunner:
context_lens=context_lens[:batch_size], context_lens=context_lens[:batch_size],
block_tables=block_tables[:batch_size], block_tables=block_tables[:batch_size],
use_cuda_graph=True, use_cuda_graph=True,
)
attn_metadata = AttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping[:batch_size],
prefill_metadata=None,
decode_metadata=decode_metadata,
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
) )
@ -950,8 +1109,8 @@ class CUDAGraphRunner:
"positions": positions, "positions": positions,
"kv_caches": kv_caches, "kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping, "slot_mapping": attn_metadata.slot_mapping,
"context_lens": attn_metadata.context_lens, "context_lens": attn_metadata.decode_metadata.context_lens,
"block_tables": attn_metadata.block_tables, "block_tables": attn_metadata.decode_metadata.block_tables,
} }
self.output_buffers = {"hidden_states": hidden_states} self.output_buffers = {"hidden_states": hidden_states}
return return
@ -972,10 +1131,10 @@ class CUDAGraphRunner:
self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True) non_blocking=True)
self.input_buffers["context_lens"].copy_(attn_metadata.context_lens, self.input_buffers["context_lens"].copy_(
non_blocking=True) attn_metadata.decode_metadata.context_lens, non_blocking=True)
self.input_buffers["block_tables"].copy_(attn_metadata.block_tables, self.input_buffers["block_tables"].copy_(
non_blocking=True) attn_metadata.decode_metadata.block_tables, non_blocking=True)
# Run the graph. # Run the graph.
self.graph.replay() self.graph.replay()