Mirror of https://git.datalinker.icu/vllm-project/vllm.git

[Core][5/N] Fully working chunked prefill e2e (#3884)

commit 67b4221a61 (parent 63e7176f26)
@@ -29,6 +29,8 @@ steps:
  - pytest -v -s test_pynccl.py
  - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
  - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
  - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py

- label: Engine Test
  command: pytest -v -s engine tokenization test_sequence.py test_config.py
@@ -177,8 +177,7 @@ if __name__ == '__main__':
                        help='block size of key/value cache')
    parser.add_argument(
        '--enable-chunked-prefill',
        type=bool,
        default=False,
        action='store_true',
        help='If True, the prefill requests can be chunked based on the '
        'max_num_batched_tokens')
    parser.add_argument(
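For context, here is a minimal, self-contained sketch of the store_true pattern the flag moves to; it is illustrative only and not the benchmark script itself:

import argparse

parser = argparse.ArgumentParser()
# With action='store_true' the flag takes no value and argparse supplies
# default=False on its own, so explicit type=/default= arguments are not needed.
parser.add_argument('--enable-chunked-prefill',
                    action='store_true',
                    help='If set, prefill requests can be chunked based on '
                         'max_num_batched_tokens')

args = parser.parse_args(['--enable-chunked-prefill'])
assert args.enable_chunked_prefill is True
assert parser.parse_args([]).enable_chunked_prefill is False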
@@ -74,25 +74,31 @@ def run_vllm(
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    gpu_memory_utilization: float = 0.9,
    download_dir: Optional[str] = None,
) -> float:
    from vllm import LLM, SamplingParams
    llm = LLM(model=model,
              tokenizer=tokenizer,
              quantization=quantization,
              tensor_parallel_size=tensor_parallel_size,
              seed=seed,
              trust_remote_code=trust_remote_code,
              dtype=dtype,
              max_model_len=max_model_len,
              gpu_memory_utilization=gpu_memory_utilization,
              enforce_eager=enforce_eager,
              kv_cache_dtype=kv_cache_dtype,
              quantization_param_path=quantization_param_path,
              device=device,
              enable_prefix_caching=enable_prefix_caching,
              download_dir=download_dir)
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
    )

    # Add the requests to the engine.
    for prompt, _, output_len in requests:
@@ -213,15 +219,15 @@ def main(args: argparse.Namespace):
                    args.output_len)

    if args.backend == "vllm":
        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
                                args.quantization, args.tensor_parallel_size,
                                args.seed, args.n, args.use_beam_search,
                                args.trust_remote_code, args.dtype,
                                args.max_model_len, args.enforce_eager,
                                args.kv_cache_dtype,
                                args.quantization_param_path, args.device,
                                args.enable_prefix_caching,
                                args.gpu_memory_utilization, args.download_dir)
        elapsed_time = run_vllm(
            requests, args.model, args.tokenizer, args.quantization,
            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
            args.trust_remote_code, args.dtype, args.max_model_len,
            args.enforce_eager, args.kv_cache_dtype,
            args.quantization_param_path, args.device,
            args.enable_prefix_caching, args.enable_chunked_prefill,
            args.max_num_batched_tokens, args.gpu_memory_utilization,
            args.download_dir)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -335,6 +341,14 @@ if __name__ == "__main__":
                        "--enable-prefix-caching",
                        action='store_true',
                        help="enable automatic prefix caching for vLLM backend.")
    parser.add_argument("--enable-chunked-prefill",
                        action='store_true',
                        help="enable chunked prefill for vLLM backend.")
    parser.add_argument('--max-num-batched-tokens',
                        type=int,
                        default=None,
                        help='maximum number of batched tokens per '
                        'iteration')
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
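Taken together, the benchmark changes thread two new knobs from the command line into the engine. A hedged usage sketch of the same keyword arguments run_vllm now forwards (the model name, token budget, and the generate call are illustrative examples, not part of this diff):

from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",      # any supported model works the same way
    enable_chunked_prefill=True,
    # Upper bound on tokens scheduled per iteration; prompts whose prefill
    # exceeds this budget are processed over several scheduler steps.
    max_num_batched_tokens=768,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)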
tests/basic_correctness/test_chunked_prefill.py (new file, 70 lines)
@@ -0,0 +1,70 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.
|
||||
|
||||
It tests chunked prefill. Chunked prefill can be enabled by
|
||||
enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens,
|
||||
prefill requests are chunked.
|
||||
|
||||
Run `pytest tests/models/test_chunked_prefill.py`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
MODELS = [
|
||||
"facebook/opt-125m",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
|
||||
@pytest.mark.parametrize("enforce_eager", [False, True])
|
||||
# NOTE: Increasing this in this suite will fail CI because we currently cannot
|
||||
# reset distributed env properly. Use a value > 1 just when you test.
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
chunked_prefill_token_size: int,
|
||||
enforce_eager: bool,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
if (tensor_parallel_size == 2 and chunked_prefill_token_size != 16
|
||||
and not enforce_eager):
|
||||
pytest.skip(f"Skip {chunked_prefill_token_size=} and {enforce_eager=} "
|
||||
"for high TP to save testing time.")
|
||||
max_num_seqs = min(chunked_prefill_token_size, 256)
|
||||
enable_chunked_prefill = False
|
||||
max_num_batched_tokens = None
|
||||
if chunked_prefill_token_size != -1:
|
||||
enable_chunked_prefill = True
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
enforce_eager=enforce_eager,
|
||||
max_num_seqs=max_num_seqs,
|
||||
)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
print(vllm_outputs[0])
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
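The docstring above describes the behavior under test: when a prompt's prefill is larger than max_num_batched_tokens, it is processed in several chunks. A small stand-alone sketch of that splitting arithmetic (plain Python, not the scheduler's actual code):

def split_prefill(prompt_len, max_num_batched_tokens):
    """Return the token count of each prefill chunk for one prompt."""
    chunks = []
    remaining = prompt_len
    while remaining > 0:
        chunk = min(remaining, max_num_batched_tokens)
        chunks.append(chunk)
        remaining -= chunk
    return chunks

# With the smallest parametrized budget a 10-token prompt needs ten 1-token
# steps; with a budget of 16 the same prompt fits in a single prefill step.
assert split_prefill(10, 1) == [1] * 10
assert split_prefill(10, 16) == [10]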
@@ -104,10 +104,10 @@ def test_chunk():
    # One chunked prefill, and one decoding.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    # The first one is decoding.
    assert seq_group_meta[0].token_chunk_size == 1
    # The first one is prefill. Scheduler guarantees ordering.
    assert seq_group_meta[0].token_chunk_size == 56
    # The second one is a chunked prefill.
    assert seq_group_meta[1].token_chunk_size == 56
    assert seq_group_meta[1].token_chunk_size == 1
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 57
@@ -157,12 +157,12 @@ def test_complex():
    # Decoding & chunked prefill & first chunk of 3rd request is scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 3
    # The first one is decoding.
    assert seq_group_meta[0].token_chunk_size == 1
    # The second one is a chunked prefill.
    # The first one is the first chunked prefill.
    assert seq_group_meta[0].token_chunk_size == 7
    # The second one is the second new chunked prefill.
    assert seq_group_meta[1].token_chunk_size == 56
    # The third one is also chunked.
    assert seq_group_meta[2].token_chunk_size == 7
    # The last one is decode.
    assert seq_group_meta[2].token_chunk_size == 1
    # Two of them are in chunked prefill.
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64
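The updated assertions encode two things: prefill groups now come before decode groups within a step, and one step can batch both kinds of work against a shared token budget. A small arithmetic sketch of the numbers these tests check (illustrative, not the scheduler implementation):

# test_chunk: one 56-token prefill chunk is scheduled ahead of one decode token,
# so the step batches 56 + 1 = 57 tokens and reports a single prefill group.
prefill_chunk_sizes = [56]
decode_sizes = [1]
assert sum(prefill_chunk_sizes) + sum(decode_sizes) == 57

# test_complex: a 7-token chunk, a 56-token chunk, and one decode token are
# batched together, which is the 64 tokens asserted above.
prefill_chunk_sizes = [7, 56]
decode_sizes = [1]
assert sum(prefill_chunk_sizes) + sum(decode_sizes) == 64
assert len(prefill_chunk_sizes) == 2  # matches out.num_prefill_groups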
@ -33,11 +33,16 @@ def test_models(
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2)
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=2,
|
||||
)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
|
||||
66
tests/distributed/test_chunked_prefill_distributed.py
Normal file
66
tests/distributed/test_chunked_prefill_distributed.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
|
||||
vLLM will allocate all the available memory, so we need to run the tests one
|
||||
by one. The solution is to pass arguments (model name) by environment
|
||||
variables.
|
||||
|
||||
Run:
|
||||
```sh
|
||||
TEST_DIST_MODEL=facebook/opt-125m pytest \
|
||||
test_chunked_prefill_distributed.py
|
||||
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest \
    test_chunked_prefill_distributed.py
|
||||
```
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
MODELS = [
|
||||
os.environ["TEST_DIST_MODEL"],
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
chunked_prefill_token_size: int,
|
||||
) -> None:
|
||||
# Add a chunked prefill config.
|
||||
max_num_seqs = min(chunked_prefill_token_size, 256)
|
||||
assert chunked_prefill_token_size != -1
|
||||
enable_chunked_prefill = True
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=2,
|
||||
max_num_seqs=max_num_seqs,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
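This distributed variant reads its model from an environment variable so that each model runs in a fresh process, matching the CI commands added above. A hypothetical driver loop illustrating that pattern outside pytest (not part of the repository):

import os
import subprocess

for model in ("facebook/opt-125m", "meta-llama/Llama-2-7b-hf"):
    # A clean process per model keeps CUDA/distributed state from leaking
    # between runs, which is why the model is passed via TEST_DIST_MODEL.
    env = dict(os.environ, TEST_DIST_MODEL=model)
    subprocess.run(
        ["pytest", "-v", "-s", "test_chunked_prefill_distributed.py"],
        env=env,
        check=True,
    )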
@ -141,7 +141,7 @@ def server(zephyr_lora_files):
|
||||
"--max-cpu-loras",
|
||||
"2",
|
||||
"--max-num-seqs",
|
||||
"128"
|
||||
"128",
|
||||
])
|
||||
ray.get(server_runner.ready.remote())
|
||||
yield server_runner
|
||||
|
||||
@ -12,7 +12,7 @@ MODELS = [
|
||||
"gpt2",
|
||||
"bigcode/tiny_starcoder_py",
|
||||
"EleutherAI/pythia-70m",
|
||||
"bigscience/bloom-560m",
|
||||
"bigscience/bloom-560m", # Testing alibi slopes.
|
||||
"microsoft/phi-2",
|
||||
"stabilityai/stablelm-3b-4e1t",
|
||||
# "allenai/OLMo-1B", # Broken
|
||||
|
||||
@ -1,14 +1,18 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config import ModelConfig, SchedulerConfig
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
|
||||
def test_prepare_prompt(batch_size):
|
||||
model_runner = ModelRunner(None, None, None, None, None)
|
||||
scheduler_config = SchedulerConfig(100000,
|
||||
100000,
|
||||
100000,
|
||||
enable_chunked_prefill=False)
|
||||
model_runner = ModelRunner(None, None, scheduler_config, None, None)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
prompt_lens = []
|
||||
@ -36,8 +40,10 @@ def test_prepare_prompt(batch_size):
|
||||
prompt_len - 1)
|
||||
selected_token_start_idx += prompt_len
|
||||
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
|
||||
_, _) = (model_runner._prepare_prompt(seq_group_metadata_list))
|
||||
_, _,
|
||||
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
|
||||
assert return_prompt_lens == prompt_lens
|
||||
assert len(slot_mapping) == len(input_tokens)
|
||||
|
||||
# Verify input metadata is correct for prompts.
|
||||
device = model_runner.device
|
||||
@ -45,8 +51,6 @@ def test_prepare_prompt(batch_size):
|
||||
assert torch.allclose(attn_metadata.prompt_lens_tensor,
|
||||
torch.tensor(prompt_lens, device=device))
|
||||
assert attn_metadata.prompt_lens == prompt_lens
|
||||
assert attn_metadata.num_prompt_tokens == sum(prompt_lens)
|
||||
assert attn_metadata.num_generation_tokens == 0
|
||||
assert attn_metadata.max_prompt_len == max(prompt_lens)
|
||||
|
||||
# Test subquery start locs.
|
||||
@ -83,23 +87,22 @@ def test_prepare_prompt(batch_size):
|
||||
assert torch.allclose(attn_metadata.block_tables, expected)
|
||||
    # Cuda graph should not be used for prefill.
|
||||
assert attn_metadata.use_cuda_graph is False
|
||||
assert attn_metadata.kv_cache_dtype == "auto"
|
||||
|
||||
assert input_tokens.shape == (sum(prompt_lens), )
|
||||
assert input_positions.shape == (sum(prompt_lens), )
|
||||
assert len(input_tokens) == sum(prompt_lens)
|
||||
assert len(input_positions) == sum(prompt_lens)
|
||||
torch.testing.assert_close(input_tokens, input_positions)
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
assert input_tokens.shape == (sum(prompt_lens), )
|
||||
assert input_positions.shape == (sum(prompt_lens), )
|
||||
assert len(input_tokens) == sum(prompt_lens)
|
||||
assert len(input_positions) == sum(prompt_lens)
|
||||
actual = sampling_metadata.selected_token_indices
|
||||
expected = torch.tensor(expected_selected_token_indices,
|
||||
device=actual.device,
|
||||
dtype=actual.dtype)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
torch.testing.assert_close(input_tokens, input_positions)
|
||||
assert input_tokens == input_positions
|
||||
|
||||
actual = sampling_metadata.selected_token_indices
|
||||
expected = torch.tensor(expected_selected_token_indices,
|
||||
@ -122,7 +125,12 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
revision=None,
|
||||
enforce_eager=False,
|
||||
)
|
||||
model_runner = ModelRunner(model_config, None, None, None, None)
|
||||
scheduler_config = SchedulerConfig(100000,
|
||||
100000,
|
||||
100000,
|
||||
enable_chunked_prefill=False)
|
||||
model_runner = ModelRunner(model_config, None, scheduler_config, None,
|
||||
None)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
prompt_lens = []
|
||||
@ -143,16 +151,15 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
assert seq_group_metadata.token_chunk_size == 1
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
input_tokens, input_positions, attn_metadata, _, _, _ = (
|
||||
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
|
||||
model_runner._prepare_decode(seq_group_metadata_list))
|
||||
assert len(slot_mapping) == len(input_tokens)
|
||||
|
||||
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
|
||||
# Verify input metadata is correct for prompts.
|
||||
device = model_runner.device
|
||||
assert attn_metadata.is_prompt is False
|
||||
assert attn_metadata.prompt_lens is None
|
||||
assert attn_metadata.num_prompt_tokens == 0
|
||||
assert attn_metadata.num_generation_tokens == expected_bs
|
||||
assert attn_metadata.max_prompt_len is None
|
||||
assert attn_metadata.subquery_start_loc is None
|
||||
assert attn_metadata.seq_start_loc is None
|
||||
@ -170,11 +177,10 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
model_runner.get_max_block_per_batch())
|
||||
    # Cuda graph should not be used for prefill.
|
||||
assert attn_metadata.use_cuda_graph is True
|
||||
assert attn_metadata.kv_cache_dtype == "auto"
|
||||
|
||||
assert input_tokens.shape == (expected_bs, )
|
||||
assert input_positions.shape == (expected_bs, )
|
||||
torch.testing.assert_close(input_tokens, input_positions)
|
||||
assert len(input_tokens) == expected_bs
|
||||
assert len(input_positions) == expected_bs
|
||||
assert input_tokens == input_positions
|
||||
|
||||
# Verify Sampling
|
||||
expected_selected_token_indices = []
|
||||
@ -190,3 +196,148 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
device=actual.device,
|
||||
dtype=actual.dtype)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def test_empty_seq_group():
|
||||
"""Verify prepare prompt and decode returns empty output."""
|
||||
model_config = ModelConfig(
|
||||
"facebook/opt-125m",
|
||||
"facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
enforce_eager=False,
|
||||
)
|
||||
model_runner = ModelRunner(model_config, None, None, None, None)
|
||||
model_runner.set_block_size(16)
|
||||
seq_group_metadata_list = []
|
||||
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
|
||||
model_runner._prepare_decode(seq_group_metadata_list))
|
||||
assert len(input_tokens) == 0
|
||||
assert len(input_positions) == 0
|
||||
assert attn_metadata is None
|
||||
assert len(slot_mapping) == 0
|
||||
|
||||
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
|
||||
_, _,
|
||||
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
|
||||
assert len(input_tokens) == 0
|
||||
assert len(input_positions) == 0
|
||||
assert attn_metadata is None
|
||||
assert len(slot_mapping) == 0
|
||||
assert len(return_prompt_lens) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", list(range(2, 128)))
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
|
||||
|
||||
def get_world_size(group=None):
|
||||
return 1
|
||||
|
||||
def mock_get_process_group_ranks(group=None):
|
||||
return [0]
|
||||
|
||||
monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size)
|
||||
monkeypatch.setattr(torch.distributed, "get_process_group_ranks",
|
||||
mock_get_process_group_ranks)
|
||||
|
||||
model_config = ModelConfig(
|
||||
"facebook/opt-125m",
|
||||
"facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
enforce_eager=enforce_eager,
|
||||
)
|
||||
scheduler_config = SchedulerConfig(100000,
|
||||
100000,
|
||||
100000,
|
||||
enable_chunked_prefill=True)
|
||||
model_runner = ModelRunner(model_config,
|
||||
None,
|
||||
scheduler_config,
|
||||
None,
|
||||
None,
|
||||
is_driver_worker=True)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
# Add prefill requests.
|
||||
prompt_lens = []
|
||||
seq_group_metadata_list = []
|
||||
prefill_metadata_list = []
|
||||
decode_metadata_list = []
|
||||
block_tables = {0: [1]}
|
||||
prefill_batch_size = batch_size // 2
|
||||
decode_batch_size = batch_size - prefill_batch_size
|
||||
for i in range(prefill_batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
prompt_len = i % (model_runner.block_size - 1) + 1
|
||||
prompt_lens.append(prompt_len)
|
||||
seq_data = SequenceData(list(range(prompt_len)))
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: seq_data},
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables=block_tables,
|
||||
)
|
||||
assert seq_group_metadata.token_chunk_size == seq_data.get_len()
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
prefill_metadata_list.append(seq_group_metadata)
|
||||
|
||||
# Add decode requests
|
||||
for i in range(prefill_batch_size, batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
prompt_len = i % (model_runner.block_size - 1) + 1
|
||||
prompt_toks = list(range(prompt_len))
|
||||
seq_data = SequenceData(prompt_toks)
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=False,
|
||||
seq_data={0: seq_data},
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables={0: [1]},
|
||||
)
|
||||
assert seq_group_metadata.token_chunk_size == 1
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
decode_metadata_list.append(seq_group_metadata)
|
||||
|
||||
(input_tokens, input_positions, attn_metadata, _, _, _,
|
||||
_) = model_runner.prepare_input_tensors(seq_group_metadata_list)
|
||||
|
||||
prefill_meta_actual = attn_metadata.prefill_metadata
|
||||
decode_meta_actual = attn_metadata.decode_metadata
|
||||
|
||||
assert len(attn_metadata.slot_mapping) == len(input_tokens)
|
||||
assert len(input_positions) == len(input_tokens)
|
||||
assert attn_metadata.kv_cache_dtype == "auto"
|
||||
assert attn_metadata.num_prefills == prefill_batch_size
|
||||
if enforce_eager:
|
||||
assert attn_metadata.num_decode_tokens == decode_batch_size
|
||||
else:
|
||||
assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
|
||||
decode_batch_size)
|
||||
assert attn_metadata.num_prefill_tokens == sum(prompt_lens)
|
||||
|
||||
# Verify attn metadata is consistent. We don't need to test individual
|
||||
# values here because they are tested above.
|
||||
prefill_meta = model_runner._prepare_prompt(
|
||||
prefill_metadata_list).attn_metadata
|
||||
decode_meta = model_runner._prepare_decode(
|
||||
decode_metadata_list).attn_metadata
|
||||
|
||||
for attr_expected, attr_actual in zip(vars(prefill_meta),
|
||||
vars(prefill_meta_actual)):
|
||||
assert attr_expected[1] == attr_actual[1]
|
||||
for attr_expected, attr_actual in zip(vars(decode_meta),
|
||||
vars(decode_meta_actual)):
|
||||
assert attr_expected[1] == attr_actual[1]
|
||||
|
||||
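test_hybrid_batches above checks how token counts add up once prefill and decode requests share one batch. A small arithmetic sketch of those expectations; the CUDA-graph padding mentioned at the end is handled by _get_graph_batch_size, whose implementation is not part of this diff:

# Illustrative numbers following the test's pattern of prompts that fit in one
# 16-token block.
block_size = 16
batch_size = 10
prefill_batch_size = batch_size // 2
decode_batch_size = batch_size - prefill_batch_size
prompt_lens = [i % (block_size - 1) + 1 for i in range(prefill_batch_size)]

num_prefill_tokens = sum(prompt_lens)    # every prompt token is scheduled
num_decode_tokens = decode_batch_size    # one new token per decode request
assert num_prefill_tokens == 15 and num_decode_tokens == 5
# With enforce_eager=False the decode side is instead padded up to the captured
# CUDA-graph batch size, i.e. _get_graph_batch_size(decode_batch_size).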
@ -1,5 +1,6 @@
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
|
||||
@ -8,4 +9,5 @@ __all__ = [
|
||||
"AttentionMetadata",
|
||||
"Attention",
|
||||
"get_attn_backend",
|
||||
"AttentionMetadataPerStage",
|
||||
]
|
||||
|
||||
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar

import torch

@@ -47,7 +47,8 @@ class AttentionBackend(ABC):


@dataclass
class AttentionMetadata:
class AttentionMetadataPerStage:
    """Attention metadata for a specific stage. I.e., prefill or decode."""

    def asdict_zerocopy(self) -> Dict[str, Any]:
        """Similar to dataclasses.asdict, but avoids deepcopying."""
@@ -59,6 +60,41 @@ class AttentionMetadata:
        }


T = TypeVar("T", bound=AttentionMetadataPerStage)


@dataclass
class AttentionMetadata(Generic[T]):
    """Attention metadata for prefill and decode batched together."""
    # Total number of prefill requests.
    num_prefills: int
    # Number of prefill tokens.
    num_prefill_tokens: int
    # Number of decode tokens. Note that it is equivalent to the number of
    # decode requests.
    num_decode_tokens: int
    # The attention metadata for prefill requests in a batch.
    # None if there's no prefill requests in a batch.
    prefill_metadata: Optional[T]
    # The attention metadata for decode requests in a batch.
    # None if there's no decode requests in a batch.
    decode_metadata: Optional[T]
    # (num_tokens,). The indices of the token slots that input tokens will be
    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
    # in block 0, and 1st slot in block 1, respectively.
    slot_mapping: torch.Tensor
    # The kv cache's data type.
    kv_cache_dtype: str

    def __post_init__(self):
        if self.num_prefill_tokens > 0:
            assert self.num_prefills > 0
            assert self.prefill_metadata is not None
        if self.num_decode_tokens > 0:
            assert self.decode_metadata is not None


class AttentionImpl(ABC):

    @abstractmethod
@@ -80,7 +116,7 @@ class AttentionImpl(ABC):
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
        kv_scale: float,
    ) -> torch.Tensor:
        raise NotImplementedError
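To make the new container concrete, here is a hedged sketch that builds an AttentionMetadata for a hybrid batch of one 3-token prefill plus two decodes, assuming the module path shown in this diff; the dummy per-stage class stands in for a real backend's metadata, and the slot values echo the slot_mapping comment above.

import torch
from vllm.attention.backends.abstract import (AttentionMetadata,
                                              AttentionMetadataPerStage)


class _DummyStageMetadata(AttentionMetadataPerStage):
    """Stand-in for a backend-specific per-stage metadata class."""


# Five flattened tokens: 3 prefill tokens first, then 2 decode tokens.
# With block_size=16, slot 35 means block 35 // 16 == 2 at offset 35 % 16 == 3.
metadata = AttentionMetadata(
    num_prefills=1,
    num_prefill_tokens=3,
    num_decode_tokens=2,
    prefill_metadata=_DummyStageMetadata(),
    decode_metadata=_DummyStageMetadata(),
    slot_mapping=torch.tensor([35, 2, 17, 48, 49]),
    kv_cache_dtype="auto",
)
assert metadata.prefill_metadata is not None  # also checked by __post_init__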
@ -11,7 +11,8 @@ import torch
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
|
||||
@ -53,7 +54,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
class FlashAttentionMetadata(AttentionMetadataPerStage,
|
||||
PagedAttentionMetadata):
|
||||
"""Metadata for FlashAttentionBackend.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
@ -68,10 +70,6 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
prompt_lens: Optional[List[int]]
|
||||
# prompt_lens stored as a tensor.
|
||||
prompt_lens_tensor: Optional[torch.Tensor]
|
||||
# The number of prompt tokens. Doesn't include padding.
|
||||
num_prompt_tokens: int
|
||||
# The number of generation tokens. Doesn't include padding.
|
||||
num_generation_tokens: int
|
||||
|
||||
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
|
||||
# |---------- N-1 iteration --------|
|
||||
@@ -107,18 +105,27 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
class FlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
    |<--------------- num_prompt_tokens -------------->|
    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
    |<--------------- num_prefill_tokens ----------------->|
    |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|

    Otherwise, the layout is as follows:
    |<------------------ num_generation_tokens (M) ----------------->|
    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
    |<----------------- num_decode_tokens ------------------>|
    |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|

    Generation tokens can contain padding when cuda-graph is used.
    Currently, prompt tokens don't contain any padding.

    The prompts might have different lengths, while the generation tokens
    always have length 1.

    If chunked prefill is enabled, prefill tokens and decode tokens can be
    batched together in a flattened 1D query.

    |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
    |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|

    Currently, cuda graph is disabled for chunked prefill, meaning there's no
    padding between prefill and decode tokens.
    """

    def __init__(
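The layout documented here is what forward() below relies on when it slices the flattened batch. A minimal tensor-only sketch of that slicing, and of scattering the per-stage results back into one output buffer (shapes are illustrative):

import torch

num_prefill_tokens, num_decode_tokens, num_heads, head_size = 5, 3, 4, 8
query = torch.randn(num_prefill_tokens + num_decode_tokens, num_heads, head_size)

# Prefill tokens come first in the flattened 1D batch; decode tokens follow.
prefill_query = query[:num_prefill_tokens]
decode_query = query[num_prefill_tokens:]
assert prefill_query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens

# Each stage writes into its own slice of a shared buffer, mirroring the
# output[:num_prefill_tokens] = ... / output[num_prefill_tokens:] = ... pattern.
output = torch.empty_like(query)
output[:num_prefill_tokens] = prefill_query   # stand-in for the prefill kernel
output[num_prefill_tokens:] = decode_query    # stand-in for the decode kernel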
@ -155,7 +162,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
attn_metadata: AttentionMetadata[FlashAttentionMetadata],
|
||||
kv_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
@ -188,52 +195,70 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
attn_metadata.kv_cache_dtype,
|
||||
kv_scale)
|
||||
|
||||
if attn_metadata.is_prompt:
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
|
||||
output = torch.empty_like(query)
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
|
||||
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
|
||||
# normal attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
output = flash_attn_varlen_func(
|
||||
out = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=attn_metadata.seq_start_loc,
|
||||
cu_seqlens_k=attn_metadata.seq_start_loc,
|
||||
max_seqlen_q=attn_metadata.max_prompt_len,
|
||||
max_seqlen_k=attn_metadata.max_prompt_len,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_prompt_len,
|
||||
max_seqlen_k=prefill_meta.max_prompt_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
)
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
# TODO(Hai) this triton kernel has regression issue (broke) to
|
||||
# deal with different data types between KV and FP8 KV cache,
|
||||
# to be addressed separately.
|
||||
output = PagedAttention.forward_prefix(
|
||||
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.subquery_start_loc,
|
||||
attn_metadata.prompt_lens_tensor,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.max_subquery_len,
|
||||
prefill_meta.block_tables,
|
||||
prefill_meta.subquery_start_loc,
|
||||
prefill_meta.prompt_lens_tensor,
|
||||
prefill_meta.context_lens,
|
||||
prefill_meta.max_subquery_len,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
else:
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
# Decoding run.
|
||||
output = PagedAttention.forward_decode(
|
||||
query,
|
||||
output[num_prefill_tokens:] = PagedAttention.forward_decode(
|
||||
decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.max_context_len,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.context_lens,
|
||||
decode_meta.max_context_len,
|
||||
attn_metadata.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
|
||||
@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm.logger import init_logger
|
||||
@ -51,7 +52,8 @@ class ROCmFlashAttentionBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
|
||||
PagedAttentionMetadata):
|
||||
"""Metadata for FlashAttentionBackend.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
@ -66,10 +68,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
prompt_lens: Optional[List[int]]
|
||||
# prompt_lens stored as a tensor.
|
||||
prompt_lens_tensor: Optional[torch.Tensor]
|
||||
# The number of prompt tokens. Doesn't include padding.
|
||||
num_prompt_tokens: int
|
||||
# The number of generation tokens. Doesn't include padding.
|
||||
num_generation_tokens: int
|
||||
|
||||
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
|
||||
# |---------- N-1 iteration --------|
|
||||
@ -117,6 +115,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
|
||||
The prompts might have different lengths, while the generation tokens
|
||||
always have length 1.
|
||||
|
||||
If chunked prefill is enabled, prefill tokens and decode tokens can be
|
||||
batched together in a flattened 1D query.
|
||||
|
||||
|<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
|
||||
|<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
|
||||
|
||||
Currently, cuda graph is disabled for chunked prefill, meaning there's no
|
||||
padding between prefill and decode tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -181,7 +188,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: ROCmFlashAttentionMetadata,
|
||||
attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
|
||||
kv_scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
@ -218,9 +225,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
kv_scale,
|
||||
)
|
||||
|
||||
if attn_metadata.is_prompt:
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
|
||||
output = torch.empty_like(query)
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
|
||||
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
|
||||
# triton attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
@ -230,63 +253,69 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
key = self.repeat_kv(key, self.num_queries_per_kv)
|
||||
value = self.repeat_kv(value, self.num_queries_per_kv)
|
||||
if self.use_naive_attn:
|
||||
output = self.attn_fuc(
|
||||
out = self.attn_fuc(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_metadata.prompt_lens,
|
||||
prefill_meta.prompt_lens,
|
||||
self.scale,
|
||||
)
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
output, _ = self.attn_func(
|
||||
out, _ = self.attn_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
attn_metadata.seq_start_loc,
|
||||
attn_metadata.seq_start_loc,
|
||||
attn_metadata.max_prompt_len,
|
||||
attn_metadata.max_prompt_len,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.max_prompt_len,
|
||||
prefill_meta.max_prompt_len,
|
||||
True,
|
||||
self.scale,
|
||||
)
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
output = self.attn_func(
|
||||
out = self.attn_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=attn_metadata.seq_start_loc,
|
||||
cu_seqlens_k=attn_metadata.seq_start_loc,
|
||||
max_seqlen_q=attn_metadata.max_prompt_len,
|
||||
max_seqlen_k=attn_metadata.max_prompt_len,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_prompt_len,
|
||||
max_seqlen_k=prefill_meta.max_prompt_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
output = PagedAttention.forward_prefix(
|
||||
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.subquery_start_loc,
|
||||
attn_metadata.prompt_lens_tensor,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.max_subquery_len,
|
||||
prefill_meta.block_tables,
|
||||
prefill_meta.subquery_start_loc,
|
||||
prefill_meta.prompt_lens_tensor,
|
||||
prefill_meta.context_lens,
|
||||
prefill_meta.max_subquery_len,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
else:
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
# Decoding run.
|
||||
output = PagedAttention.forward_decode(
|
||||
query,
|
||||
output[num_prefill_tokens:] = PagedAttention.forward_decode(
|
||||
decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.max_context_len,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.context_lens,
|
||||
decode_meta.max_context_len,
|
||||
attn_metadata.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
|
||||
@ -7,7 +7,8 @@ import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
|
||||
@ -49,17 +50,14 @@ class TorchSDPABackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
|
||||
"""Metadata for TorchSDPABackend.
|
||||
"""
|
||||
# Currently, input sequences can only contain all prompts
|
||||
# or all decoding. True if all sequences are prompts.
|
||||
is_prompt: bool
|
||||
slot_mapping: torch.Tensor
|
||||
prompt_lens: Optional[List[int]]
|
||||
prompt_lens_tensor: Optional[torch.Tensor]
|
||||
num_prompt_tokens: int
|
||||
num_generation_tokens: int
|
||||
|
||||
max_subquery_len: Optional[int] = None
|
||||
max_prompt_len: Optional[int] = None
|
||||
@ -113,7 +111,7 @@ class TorchSDPABackendImpl(AttentionImpl):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: TorchSDPAMetadata,
|
||||
attn_metadata: AttentionMetadata[TorchSDPAMetadata],
|
||||
kv_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with torch SDPA and PagedAttention.
|
||||
@ -142,36 +140,51 @@ class TorchSDPABackendImpl(AttentionImpl):
|
||||
attn_metadata.kv_cache_dtype,
|
||||
kv_scale)
|
||||
|
||||
if attn_metadata.is_prompt:
|
||||
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
|
||||
output = torch.empty_like(query)
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
if (kv_cache is None or prefill_meta.block_tables.numel() == 0):
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||
value = value.repeat_interleave(self.num_queries_per_kv,
|
||||
dim=1)
|
||||
|
||||
if attn_metadata.attn_bias is None:
|
||||
if prefill_meta.attn_bias is None:
|
||||
if self.alibi_slopes is not None:
|
||||
att_masks = _make_alibi_bias(
|
||||
self.alibi_slopes, query.dtype,
|
||||
attn_metadata.prompt_lens) # type: ignore
|
||||
prefill_meta.prompt_lens) # type: ignore
|
||||
elif self.sliding_window is not None:
|
||||
att_masks = _make_sliding_window_bias(
|
||||
attn_metadata.prompt_lens, self.sliding_window,
|
||||
prefill_meta.prompt_lens, self.sliding_window,
|
||||
query.dtype) # type: ignore
|
||||
else:
|
||||
att_masks = [None] * len(attn_metadata.prompt_lens)
|
||||
attn_metadata.attn_bias = att_masks
|
||||
att_masks = [None] * len(prefill_meta.prompt_lens)
|
||||
prefill_meta.attn_bias = att_masks
|
||||
|
||||
query = query.movedim(0, query.dim() - 2)
|
||||
key = key.movedim(0, key.dim() - 2)
|
||||
value = value.movedim(0, value.dim() - 2)
|
||||
|
||||
start = 0
|
||||
output = torch.empty(
|
||||
(num_tokens, self.num_heads, self.head_size),
|
||||
dtype=query.dtype)
|
||||
for prompt_len, mask in zip(attn_metadata.prompt_lens,
|
||||
attn_metadata.attn_bias):
|
||||
out = torch.empty((num_tokens, self.num_heads, self.head_size),
|
||||
dtype=query.dtype)
|
||||
for prompt_len, mask in zip(prefill_meta.prompt_lens,
|
||||
prefill_meta.attn_bias):
|
||||
end = start + prompt_len
|
||||
sub_out = scaled_dot_product_attention(
|
||||
query[:, start:end, :],
|
||||
@ -181,28 +194,32 @@ class TorchSDPABackendImpl(AttentionImpl):
|
||||
dropout_p=0.0,
|
||||
is_causal=not self.need_mask,
|
||||
scale=self.scale).movedim(query.dim() - 2, 0)
|
||||
output[start:end, :, :] = sub_out
|
||||
out[start:end, :, :] = sub_out
|
||||
start = end
|
||||
assert out.shape == output[:num_prefill_tokens].shape
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
raise RuntimeError(
|
||||
"Torch SDPA backend doesn't support prefix decoding.")
|
||||
|
||||
else:
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
# Decoding run.
|
||||
output = PagedAttention.forward_decode(
|
||||
query,
|
||||
out = PagedAttention.forward_decode(
|
||||
decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.max_context_len,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.context_lens,
|
||||
decode_meta.max_context_len,
|
||||
attn_metadata.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
self.alibi_slopes,
|
||||
kv_scale,
|
||||
)
|
||||
assert out.shape == output[num_prefill_tokens:].shape
|
||||
            output[num_prefill_tokens:] = out
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
|
||||
LowerTriangularMaskWithTensorBias)
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm.logger import init_logger
|
||||
@ -54,7 +55,7 @@ class XFormersBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
|
||||
"""Metadata for XFormersbackend.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
@ -65,19 +66,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# Currently, input sequences can only contain all prompts
|
||||
# or all decoding. True if all sequences are prompts.
|
||||
is_prompt: bool
|
||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
||||
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
||||
# in block 0, and 1st slot in block 1, respectively.
|
||||
slot_mapping: torch.Tensor
|
||||
# (batch_size,). The prompt length per sequence. None if it is a decoding.
|
||||
prompt_lens: Optional[List[int]]
|
||||
# prompt_lens stored as a tensor.
|
||||
prompt_lens_tensor: Optional[torch.Tensor]
|
||||
# The number of prompt tokens. Doesn't include padding.
|
||||
num_prompt_tokens: int
|
||||
# The number of generation tokens. Doesn't include padding.
|
||||
num_generation_tokens: int
|
||||
|
||||
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
|
||||
# |---------- N-1 iteration --------|
|
||||
@ -123,18 +115,27 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
class XFormersImpl(AttentionImpl):
|
||||
"""
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
|<--------------- num_prompt_tokens --------------->|
|
||||
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->|
|
||||
|<--------------- num_prefill_tokens ----------------->|
|
||||
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
|
||||
|
||||
Otherwise, the layout is as follows:
|
||||
|<------------------ num_generation_tokens (M) ----------------->|
|
||||
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
|
||||
|<----------------- num_decode_tokens ------------------>|
|
||||
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
|
||||
|
||||
Generation tokens can contain padding when cuda-graph is used.
|
||||
Currently, prompt tokens don't contain any padding.
|
||||
|
||||
The prompts might have different lengths, while the generation tokens
|
||||
always have length 1.
|
||||
|
||||
If chunked prefill is enabled, prefill tokens and decode tokens can be
|
||||
batched together in a flattened 1D query.
|
||||
|
||||
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|
||||
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
|
||||
|
||||
Currently, cuda graph is disabled for chunked prefill, meaning there's no
|
||||
padding between prefill and decode tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -170,7 +171,7 @@ class XFormersImpl(AttentionImpl):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: XFormersMetadata,
|
||||
attn_metadata: AttentionMetadata[XFormersMetadata],
|
||||
kv_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with xFormers and PagedAttention.
|
||||
@ -202,59 +203,61 @@ class XFormersImpl(AttentionImpl):
|
||||
attn_metadata.kv_cache_dtype,
|
||||
kv_scale)
|
||||
|
||||
if attn_metadata.is_prompt:
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
|
||||
output = torch.empty_like(query)
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
|
||||
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
|
||||
# normal attention.
|
||||
# block tables are empty if the prompt does not have a cached
|
||||
# prefix.
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
|
||||
# project the key and value tensors to the desired number of
|
||||
# heads.
|
||||
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
|
||||
query = query.view(query.shape[0], self.num_kv_heads,
|
||||
self.num_queries_per_kv,
|
||||
query.shape[-1])
|
||||
key = key[:, :,
|
||||
None, :].expand(key.shape[0], self.num_kv_heads,
|
||||
self.num_queries_per_kv,
|
||||
key.shape[-1])
|
||||
value = value[:, :,
|
||||
None, :].expand(value.shape[0],
|
||||
self.num_kv_heads,
|
||||
self.num_queries_per_kv,
|
||||
value.shape[-1])
|
||||
|
||||
output = self._run_memory_efficient_xformers_forward(
|
||||
query, key, value, attn_metadata)
|
||||
out = self._run_memory_efficient_xformers_forward(
|
||||
query, key, value, prefill_meta)
|
||||
assert out.shape == output[:num_prefill_tokens].shape
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
# TODO(Hai) this triton kernel has regression issue (broke) to
|
||||
# deal with different data types between KV and FP8 KV cache,
|
||||
# to be addressed separately.
|
||||
output = PagedAttention.forward_prefix(
|
||||
out = PagedAttention.forward_prefix(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.subquery_start_loc,
|
||||
attn_metadata.prompt_lens_tensor,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.max_subquery_len,
|
||||
prefill_meta.block_tables,
|
||||
prefill_meta.subquery_start_loc,
|
||||
prefill_meta.prompt_lens_tensor,
|
||||
prefill_meta.context_lens,
|
||||
prefill_meta.max_subquery_len,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
else:
|
||||
# Decoding run.
|
||||
output = PagedAttention.forward_decode(
|
||||
query,
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
output[num_prefill_tokens:] = PagedAttention.forward_decode(
|
||||
decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.max_context_len,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.context_lens,
|
||||
decode_meta.max_context_len,
|
||||
attn_metadata.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
@ -275,13 +278,30 @@ class XFormersImpl(AttentionImpl):
|
||||
"""Attention for 1D query of multiple prompts. Multiple prompt
|
||||
tokens are flattened in to `query` input.
|
||||
|
||||
See https://facebookresearch.github.io/xformers/components/ops.html
|
||||
for API spec.
|
||||
|
||||
Args:
|
||||
output: shape = [num_prompt_tokens, num_heads, head_size]
|
||||
query: shape = [num_prompt_tokens, num_heads, head_size]
|
||||
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
|
||||
output: shape = [num_prefill_tokens, num_heads, head_size]
|
||||
query: shape = [num_prefill_tokens, num_heads, head_size]
|
||||
key: shape = [num_prefill_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
"""
|
||||
original_query = query
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
# GQA/MQA requires the shape [B, M, G, H, K].
|
||||
# Note that the output also has the same shape (which is different
|
||||
# from a spec from the doc).
|
||||
query = query.view(query.shape[0], self.num_kv_heads,
|
||||
self.num_queries_per_kv, query.shape[-1])
|
||||
key = key[:, :,
|
||||
None, :].expand(key.shape[0], self.num_kv_heads,
|
||||
self.num_queries_per_kv, key.shape[-1])
|
||||
value = value[:, :,
|
||||
None, :].expand(value.shape[0], self.num_kv_heads,
|
||||
self.num_queries_per_kv,
|
||||
value.shape[-1])
|
||||
# Set attention bias if not provided. This typically happens at
|
||||
# the very attention layer of every iteration.
|
||||
# FIXME(woosuk): This is a hack.
|
||||
@ -302,6 +322,7 @@ class XFormersImpl(AttentionImpl):
|
||||
# TODO(woosuk): Too many view operations. Let's try to reduce
|
||||
# them in the future for code readability.
|
||||
if self.alibi_slopes is None:
|
||||
# Add the batch dimension.
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
@ -312,14 +333,13 @@ class XFormersImpl(AttentionImpl):
|
||||
attn_bias=attn_metadata.attn_bias[0],
|
||||
p=0.0,
|
||||
scale=self.scale)
|
||||
|
||||
return out.view_as(query)
|
||||
return out.view_as(original_query)
|
||||
|
||||
# Attention with alibi slopes.
|
||||
# FIXME(woosuk): Because xformers does not support dynamic sequence
|
||||
# lengths with custom attention bias, we process each prompt one by
|
||||
# one. This is inefficient, especially when we have many short prompts.
|
||||
output = torch.empty_like(query)
|
||||
output = torch.empty_like(original_query)
|
||||
start = 0
|
||||
for i, prompt_len in enumerate(attn_metadata.prompt_lens):
|
||||
end = start + prompt_len
|
||||
@ -331,7 +351,7 @@ class XFormersImpl(AttentionImpl):
|
||||
p=0.0,
|
||||
scale=self.scale)
|
||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||
output[start:end].copy_(out.squeeze(0))
|
||||
output[start:end].copy_(out.view_as(original_query[start:end]))
|
||||
start += prompt_len
|
||||
return output
|
||||
|
||||
|
||||
@ -4,7 +4,8 @@ from typing import List, Optional
import torch
import torch.nn as nn

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.abstract import (AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.selector import get_attn_backend

@ -41,7 +42,7 @@ class Attention(nn.Module):
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
kv_scale: float = 1.0,
) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,

@ -13,11 +13,6 @@ _PARTITION_SIZE = 512
@dataclass
class PagedAttentionMetadata:
"""Metadata for PagedAttention."""
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
@ -31,7 +26,6 @@ class PagedAttentionMetadata:
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
kv_cache_dtype: str

class PagedAttention:

@ -565,9 +565,16 @@ class SchedulerConfig:
if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens
else:
# If max_model_len is too short, use 2048 as the default value for
# higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048)
if enable_chunked_prefill:
# For chunked prefill, choose the well-tuned batch size.
self.max_num_batched_tokens = 768
else:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048)
if enable_chunked_prefill:
logger.info("Chunked prefill is enabled (EXPERIMENTAL).")

self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.use_v2_block_manager = use_v2_block_manager
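Note: the default-selection logic above boils down to a small decision. A standalone sketch of the same branching (this is an illustrative helper, not the actual SchedulerConfig code):

from typing import Optional

def default_max_num_batched_tokens(max_num_batched_tokens: Optional[int],
                                   max_model_len: int,
                                   enable_chunked_prefill: bool) -> int:
    # Mirror of the branching introduced in this hunk (illustrative only).
    if max_num_batched_tokens is not None:
        return max_num_batched_tokens      # explicit user override wins
    if enable_chunked_prefill:
        return 768                         # well-tuned per-step chunk budget
    return max(max_model_len, 2048)        # room for at least one full prompt

assert default_max_num_batched_tokens(None, 4096, True) == 768
assert default_max_num_batched_tokens(None, 512, False) == 2048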
@ -140,7 +140,11 @@ class SchedulerOutputs:

@property
def lora_requests(self) -> Set[LoRARequest]:
return {g.seq_group.lora_request for g in self.scheduled_seq_groups}
return {
g.seq_group.lora_request
for g in self.scheduled_seq_groups
if g.seq_group.lora_request is not None
}

@dataclass
@ -826,13 +830,12 @@ class Scheduler:
# Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out)

return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups +
running_scheduled.decode_seq_groups +
running_scheduled.prefill_seq_groups +
swapped_in.decode_seq_groups +
swapped_in.prefill_seq_groups),
swapped_in.prefill_seq_groups +
running_scheduled.decode_seq_groups +
swapped_in.decode_seq_groups),
num_prefill_groups=(len(prefills.seq_groups) +
len(swapped_in.prefill_seq_groups) +
len(running_scheduled.prefill_seq_groups)),
@ -907,7 +910,7 @@ class Scheduler:

# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
is_prompt = i < scheduler_outputs.num_prefill_groups
is_prompt = seq_group.is_prefill()
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=is_prompt,

@ -173,10 +173,18 @@ def broadcast_tensor_dict(
torch.distributed.broadcast_object_list([metadata_list],
src=src,
group=group)
async_handles = []
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
tensor = tensor_dict[key]
torch.distributed.broadcast(tensor, src=src, group=group)
async_handles.append(
torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True))
for async_handle in async_handles:
async_handle.wait()

else:
recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list,
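Note: the broadcast change above switches the per-tensor broadcasts from blocking calls to `async_op=True`, so every tensor is put on the wire before any wait. A generic sketch of that pattern with torch.distributed, assuming an already-initialized process group (the function name is illustrative):

import torch.distributed as dist

def broadcast_all_then_wait(tensors, src=0, group=None):
    # Kick off every broadcast without blocking...
    handles = [
        dist.broadcast(t, src=src, group=group, async_op=True)
        for t in tensors
    ]
    # ...and only then wait, so the collectives can overlap.
    for handle in handles:
        handle.wait()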
@ -386,9 +386,8 @@ class EngineArgs:
'prompt latency) before scheduling next prompt.')
parser.add_argument(
'--enable-chunked-prefill',
type=bool,
default=False,
help='If True, the prefill requests can be chunked based on the '
action='store_true',
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens')

parser.add_argument(
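Note: `type=bool` is an argparse pitfall rather than a boolean flag: the raw string is passed through `bool()`, so any non-empty value (including "False") is truthy. `action='store_true'` gives the intended flag semantics. A quick demonstration of the difference:

import argparse

broken = argparse.ArgumentParser()
broken.add_argument('--enable-chunked-prefill', type=bool, default=False)
# bool("False") is True, so the option can never actually be turned off.
assert broken.parse_args(
    ['--enable-chunked-prefill', 'False']).enable_chunked_prefill is True

fixed = argparse.ArgumentParser()
fixed.add_argument('--enable-chunked-prefill', action='store_true')
assert fixed.parse_args([]).enable_chunked_prefill is False
assert fixed.parse_args(['--enable-chunked-prefill']).enable_chunked_prefill is True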
@ -633,7 +633,10 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size)
self._process_sequence_group_outputs(seq_group, outputs)
# If uncomputed tokens > 0, it means prefill is chunked.
# We don't need to process outputs in that case.
if seq_group.get_num_uncomputed_tokens() == 0:
self._process_sequence_group_outputs(seq_group, outputs)

# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()
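Note: the guard above means a partially prefilled request produces no visible output until its last chunk lands. A minimal sketch of that control flow, using a hypothetical stand-in for the sequence group (none of these names come from the diff):

from dataclasses import dataclass

@dataclass
class FakeSeqGroup:
    prompt_len: int
    computed: int = 0

    def update_num_computed_tokens(self, chunk: int) -> None:
        self.computed += chunk

    def get_num_uncomputed_tokens(self) -> int:
        return self.prompt_len - self.computed

group = FakeSeqGroup(prompt_len=1000)
for chunk in (768, 232):                      # two prefill chunks
    group.update_num_computed_tokens(chunk)
    if group.get_num_uncomputed_tokens() == 0:
        print("prefill finished -> process outputs / sample first token")
    else:
        print("prefill still chunked -> skip output processing")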
@ -267,12 +267,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):

def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1
indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
embedding_len = self.indices_len[3]
indices = self.embeddings_indices[1][:embedding_len].view_as(x)
full_lora_a_embeddings = F.embedding(
x + indices,
self.lora_a_stacked_2d,
)
indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
indices = self.embeddings_indices[0][:embedding_len].view_as(x)
full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask))

@ -500,7 +500,8 @@ class SequenceGroup:
def get_num_uncomputed_tokens(self) -> int:
num_uncomputed_tokens = 0
for seq in self.get_seqs():
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
if not seq.is_finished():
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
return num_uncomputed_tokens

def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:

@ -1,12 +1,14 @@
import contextlib
import time
from typing import Dict, List, Optional, Set, Tuple
from enum import IntEnum
from typing import Dict, List, NamedTuple, Optional, Set, Tuple

import numpy as np
import torch
import torch.nn as nn

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
get_attn_backend)
from vllm.config import (DeviceConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
@ -37,6 +39,66 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
]

class PreparePromptMetadata(NamedTuple):
input_tokens: List[int]
input_positions: List[int]
attn_metadata: Optional[AttentionMetadataPerStage]
prompt_lens: List[int]
subquery_lens: List[int]
lora_index_mapping: List[int]
lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest]
multi_modal_input: Optional[torch.Tensor]
slot_mapping: List[int]

@classmethod
def empty(cls):
return PreparePromptMetadata(
input_tokens=[],
input_positions=[],
attn_metadata=None,
prompt_lens=[],
subquery_lens=[],
lora_index_mapping=[],
lora_prompt_mapping=[],
lora_requests=set(),
multi_modal_input=None,
slot_mapping=[],
)

class PrepareDecodeMetadata(NamedTuple):
input_tokens: List[int]
input_positions: List[int]
attn_metadata: Optional[AttentionMetadata]
lora_index_mapping: List[int]
lora_prompt_mapping: List[int]
lora_requests: Set[LoRARequest]
slot_mapping: List[int]

@classmethod
def empty(cls):
return PrepareDecodeMetadata(
input_tokens=[],
input_positions=[],
attn_metadata=None,
lora_index_mapping=[],
lora_prompt_mapping=[],
lora_requests=set(),
slot_mapping=[],
)

# How batches are constructed.
class BatchType(IntEnum):
# Every batch is prefill.
PREFILL = 0
# Every batch is decode.
DECODE = 1
# Batch is a mixture of prefill and decode.
MIXED = 2

class ModelRunner:

def __init__(
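Note: the PreparePromptMetadata/PrepareDecodeMetadata NamedTuples above replace long positional return tuples from `_prepare_prompt` and `_prepare_decode`, and their `empty()` classmethods let an empty request list short-circuit cleanly; BatchType records how the step was assembled. A trimmed sketch of the same pattern (class name and fields reduced for brevity, not the actual definitions):

from enum import IntEnum
from typing import List, NamedTuple, Optional

class PrepareMetadataSketch(NamedTuple):
    input_tokens: List[int]
    input_positions: List[int]
    attn_metadata: Optional[object]

    @classmethod
    def empty(cls) -> "PrepareMetadataSketch":
        return cls(input_tokens=[], input_positions=[], attn_metadata=None)

class BatchTypeSketch(IntEnum):
    PREFILL = 0   # every request in the batch is a prefill
    DECODE = 1    # every request is a decode
    MIXED = 2     # chunked prefill allows mixing both in one step

meta = PrepareMetadataSketch.empty()
assert meta.attn_metadata is None and meta.input_tokens == []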
@ -152,10 +214,7 @@ class ModelRunner:
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
List[int], List[int], List[int], Set[LoRARequest],
torch.Tensor]:
assert len(seq_group_metadata_list) > 0
) -> PreparePromptMetadata:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
@ -169,6 +228,9 @@ class ModelRunner:
prefix_block_tables: List[List[int]] = []
multi_modal_input_list: List[torch.Tensor] = []

if len(seq_group_metadata_list) == 0:
return PreparePromptMetadata.empty()

for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
@ -178,7 +240,8 @@ class ModelRunner:
computed_block_nums = seq_group_metadata.computed_block_nums
if (self.scheduler_config is not None
and self.scheduler_config.chunked_prefill_enabled
and computed_block_nums is not None):
and not (computed_block_nums is None
or computed_block_nums == [])):
raise RuntimeError(
"chunked prefill cannot be used with prefix caching "
"now.")
@ -190,13 +253,8 @@ class ModelRunner:
# it contains output tokens.
prefill_end = min(seq_data.get_len(),
computed_len + token_chunk_size)
# TODO(sang): Rename it after chunked prefill is introduced.
prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end]
prompt_len = len(prompt_tokens)
# Right now, the prefill_end is always same as the length of
# sequence. However, once chunked prefill is introduced, this
# assumption can be changed.
assert prefill_end == seq_data.get_len()
prompt_len = prefill_end
prompt_lens.append(prompt_len)

# NOTE: This only works for oooooooxxx style attention.
@ -206,6 +264,14 @@ class ModelRunner:
computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:]
prefix_block_tables.append(computed_block_nums)
elif self.scheduler_config.chunked_prefill_enabled:
if seq_group_metadata.block_tables is not None:
# Prefill has chunked before.
block_table = seq_group_metadata.block_tables[seq_id]
prefix_block_tables.append(block_table)
else:
# The first prefill.
prefix_block_tables.append([])
else:
prefix_block_tables.append([])
# Right now, prefill start is always 0. However, this
@ -267,20 +333,8 @@ class ModelRunner:

max_subquery_len = max(subquery_lens)
max_prompt_len = max(prompt_lens)
num_prompt_tokens = len(input_tokens)
assert max_subquery_len > 0

input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
lora_index_mapping = lora_index_mapping

context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
@ -332,11 +386,8 @@ class ModelRunner:

attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
slot_mapping=slot_mapping,
prompt_lens=prompt_lens,
prompt_lens_tensor=prompt_lens_tensor,
num_prompt_tokens=num_prompt_tokens,
num_generation_tokens=0,
max_subquery_len=max_subquery_len,
max_context_len=None,
max_prompt_len=max_prompt_len,
@ -345,18 +396,25 @@ class ModelRunner:
context_lens=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests, multi_modal_input)

return PreparePromptMetadata(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
prompt_lens=prompt_lens,
subquery_lens=subquery_lens,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
multi_modal_input=multi_modal_input,
slot_mapping=slot_mapping,
)

def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0
) -> PrepareDecodeMetadata:
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
@ -366,6 +424,9 @@ class ModelRunner:
lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set()

if len(seq_group_metadata_list) == 0:
return PrepareDecodeMetadata.empty()

for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
assert seq_group_metadata.token_chunk_size == 1
@ -424,15 +485,6 @@ class ModelRunner:
lora_index_mapping.append(0)
batch_size = graph_batch_size

input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device=self.device)
@ -440,9 +492,9 @@ class ModelRunner:
if use_captured_graph:
# When using cuda-graph all these tensors should be
# padded.
assert context_lens.shape[0] == input_tokens.shape[0]
assert context_lens.shape[0] == input_positions.shape[0]
assert context_lens.shape[0] == slot_mapping.shape[0]
assert context_lens.shape[0] == len(input_tokens)
assert context_lens.shape[0] == len(input_positions)
assert context_lens.shape[0] == len(slot_mapping)

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
@ -464,11 +516,8 @@ class ModelRunner:

attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping,
prompt_lens=None,
prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=len(input_tokens),
max_subquery_len=None,
max_context_len=max_context_len,
max_prompt_len=None,
@ -477,10 +526,16 @@ class ModelRunner:
context_lens=context_lens,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, attn_metadata,
lora_index_mapping, lora_prompt_mapping, lora_requests)
return PrepareDecodeMetadata(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
lora_index_mapping=lora_index_mapping,
lora_prompt_mapping=lora_prompt_mapping,
lora_requests=lora_requests,
slot_mapping=slot_mapping,
)

def _prepare_sample(
self,
@ -586,26 +641,66 @@ class ModelRunner:
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[int], LoRAMapping, torch.Tensor]:
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
prefill_reqs = []
decode_reqs = []
for seq_group_meta in seq_group_metadata_list:
if seq_group_meta.is_prompt:
prefill_reqs.append(seq_group_meta)
else:
decode_reqs.append(seq_group_meta)

# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests, multi_modal_input
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions, attn_metadata,
lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
subquery_lens = None
multi_modal_input = None
(
input_tokens,
input_positions,
prefill_attn_metadata,
prompt_lens,
subquery_lens,
lora_index_mapping,
lora_prompt_mapping,
lora_requests,
multi_modal_input,
slot_mapping,
) = self._prepare_prompt(prefill_reqs)
(
decode_input_tokens,
decode_input_positions,
decode_attn_metadata,
decode_lora_index_mapping,
decode_lora_prompt_mapping,
decode_lora_requests,
decode_slot_mapping,
) = self._prepare_decode(decode_reqs)
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens)

if not self.scheduler_config.chunked_prefill_enabled:
assert (len(prefill_reqs) and len(decode_reqs)) == 0

num_prefills = len(prompt_lens)
num_prefill_tokens = len(input_tokens)
num_decode_tokens = len(decode_input_tokens)

# Coalesce tensors. Note that attn_metadata is currently not
# coalesced for simplicity.
input_tokens.extend(decode_input_tokens)
input_positions.extend(decode_input_positions)
slot_mapping.extend(decode_slot_mapping)
lora_index_mapping.extend(decode_lora_index_mapping)
lora_prompt_mapping.extend(decode_lora_prompt_mapping)
lora_requests.update(decode_lora_requests)

input_tokens = torch.tensor(input_tokens,
dtype=torch.long,
device=self.device)
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device)
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device)

if self.lora_config:
lora_mapping = LoRAMapping(
lora_index_mapping,
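Note: the hunk above splits the scheduled groups into prefill and decode requests, prepares each side separately, and then concatenates the flat Python lists (prefills first) before building tensors. A simplified sketch of that ordering, using plain token lists in place of the real metadata (all values are hypothetical):

# Per-request token lists; prefills carry whole chunks, decodes one token each.
prefill_tokens = [[1, 2, 3, 4], [7, 8]]
decode_tokens = [[42], [43]]

input_tokens = [t for req in prefill_tokens for t in req]
num_prefill_tokens = len(input_tokens)
decode_flat = [t for req in decode_tokens for t in req]
num_decode_tokens = len(decode_flat)

# Coalesce: prefill tokens always come before decode tokens, so downstream
# kernels can split the flat batch with a single offset.
input_tokens.extend(decode_flat)
assert input_tokens[:num_prefill_tokens] == [1, 2, 3, 4, 7, 8]
assert input_tokens[num_prefill_tokens:] == [42, 43]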
@ -615,6 +710,16 @@ class ModelRunner:
lora_mapping = None

# Broadcast the metadata.
# If batch contains both prefill and decode, it sends 2 broadcasts.
# If it only contains 1 type, it triggers a single broadcast.
if (prefill_attn_metadata is not None
and decode_attn_metadata is not None):
batch_type = BatchType.MIXED
elif prefill_attn_metadata is not None:
batch_type = BatchType.PREFILL
else:
batch_type = BatchType.DECODE

metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
@ -623,19 +728,49 @@ class ModelRunner:
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
"multi_modal_input": multi_modal_input,
"num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
"batch_type": batch_type,
}
metadata_dict.update(attn_metadata.asdict_zerocopy())
if prefill_attn_metadata is not None:
metadata_dict.update(prefill_attn_metadata.asdict_zerocopy())
else:
metadata_dict.update(decode_attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)

# Broadcast decode attn metadata for mixed batch type.
# The additional broadcast costs 300us overhead on 4 A10 GPUs.
# We can potentially reduce the overhead by coalescing tensors.
if batch_type == BatchType.MIXED:
assert decode_attn_metadata is not None
metadata_dict = decode_attn_metadata.asdict_zerocopy()
broadcast_tensor_dict(metadata_dict, src=0)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions")
slot_mapping = metadata_dict.pop("slot_mapping")
num_prefills = metadata_dict.pop("num_prefills")
selected_token_indices = metadata_dict.pop(
"selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
multi_modal_input = metadata_dict.pop("multi_modal_input")
attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
num_prefill_tokens = metadata_dict.pop("num_prefill_tokens")
num_decode_tokens = metadata_dict.pop("num_decode_tokens")
batch_type = metadata_dict.pop("batch_type")

# Create an attention metadata.
prefill_attn_metadata = None
decode_attn_metadata = None
if batch_type == BatchType.PREFILL or batch_type == BatchType.MIXED:
prefill_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
else:
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
@ -646,6 +781,23 @@ class ModelRunner:
perform_sampling=False,
)

# if it is a mixed batch, decode attn_metadata is broadcasted
# separately.
if batch_type == BatchType.MIXED:
metadata_dict = broadcast_tensor_dict(src=0)
decode_attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)

attn_metadata = AttentionMetadata(
num_prefills=num_prefills,
slot_mapping=slot_mapping,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
prefill_metadata=prefill_attn_metadata,
decode_metadata=decode_attn_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)

return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping,
multi_modal_input)
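Note: after this refactor a single step carries one AttentionMetadata that wraps an optional per-stage prefill metadata and an optional decode metadata, plus the token counts needed to split the flat tensors. A rough sketch of that container shape (field names follow the diff, but the class here is illustrative, not the real definition):

from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

T = TypeVar("T")

@dataclass
class AttentionMetadataSketch(Generic[T]):
    num_prefills: int
    num_prefill_tokens: int
    num_decode_tokens: int
    slot_mapping: list          # flat: prefill tokens first, then decode tokens
    prefill_metadata: Optional[T]
    decode_metadata: Optional[T]

# A decode-only batch carries no prefill stage at all.
decode_only = AttentionMetadataSketch(
    num_prefills=0, num_prefill_tokens=0, num_decode_tokens=8,
    slot_mapping=list(range(8)), prefill_metadata=None,
    decode_metadata=object())
assert decode_only.prefill_metadata is None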
@ -663,8 +815,10 @@ class ModelRunner:
if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)

# Execute the model.
if attn_metadata.use_cuda_graph:
# Currently cuda graph is only supported by the decode phase.
prefill_meta = attn_metadata.prefill_metadata
decode_meta = attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
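Note: with mixed batches possible, the CUDA-graph fast path is only taken when there is no prefill stage and the decode metadata was captured for graphs, as the condition above shows. A condensed sketch of that dispatch (the runner objects and function name here are placeholders):

def select_model_executable(prefill_meta, decode_meta, batch_size,
                            graph_runners, eager_model):
    # Pick a captured CUDA graph only for pure-decode, graph-captured batches.
    if prefill_meta is None and getattr(decode_meta, "use_cuda_graph", False):
        return graph_runners[batch_size]
    # Any batch containing a (possibly chunked) prefill falls back to eager.
    return eager_model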
@ -842,13 +996,10 @@ class ModelRunner:
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
# Create dummy attn_metadata.
attn_metadata = self.attn_backend.make_metadata(
decode_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping[:batch_size],
prompt_lens=None,
prompt_lens_tensor=None,
num_prompt_tokens=0,
num_generation_tokens=batch_size,
max_subquery_len=None,
max_context_len=self.max_context_len_to_capture,
max_prompt_len=None,
@ -857,6 +1008,14 @@ class ModelRunner:
context_lens=context_lens[:batch_size],
block_tables=block_tables[:batch_size],
use_cuda_graph=True,
)
attn_metadata = AttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=slot_mapping[:batch_size],
prefill_metadata=None,
decode_metadata=decode_metadata,
kv_cache_dtype=self.kv_cache_dtype,
)

@ -950,8 +1109,8 @@ class CUDAGraphRunner:
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": attn_metadata.slot_mapping,
"context_lens": attn_metadata.context_lens,
"block_tables": attn_metadata.block_tables,
"context_lens": attn_metadata.decode_metadata.context_lens,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
self.output_buffers = {"hidden_states": hidden_states}
return
@ -972,10 +1131,10 @@ class CUDAGraphRunner:
self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True)
self.input_buffers["context_lens"].copy_(attn_metadata.context_lens,
non_blocking=True)
self.input_buffers["block_tables"].copy_(attn_metadata.block_tables,
non_blocking=True)
self.input_buffers["context_lens"].copy_(
attn_metadata.decode_metadata.context_lens, non_blocking=True)
self.input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
# Run the graph.
self.graph.replay()