[1/n][Chunked Prefill] Refactor input query shapes (#3236)

This commit is contained in:
SangBin Cho 2024-03-21 06:46:05 +09:00 committed by GitHub
parent 426ec4ec67
commit 6e435de766
18 changed files with 579 additions and 263 deletions
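The heart of this refactor: prompt inputs move from padded 2D tensors of shape [batch_size, max_prompt_len] to flattened 1D tensors of shape [sum(prompt_lens)], indexed by cumulative start locations. A minimal PyTorch sketch of the two layouts (illustrative lengths, not code from this diff):

import torch

# Three prompts of different lengths (illustrative values).
prompt_lens = [4, 2, 3]

# Old layout: pad every prompt to the longest one -> [batch_size, max_prompt_len].
max_prompt_len = max(prompt_lens)
padded = torch.zeros(len(prompt_lens), max_prompt_len, dtype=torch.long)
for i, plen in enumerate(prompt_lens):
    padded[i, :plen] = torch.arange(plen)

# New layout: concatenate all prompt tokens -> [sum(prompt_lens)], plus
# cumulative start locations used to slice each prompt back out.
flat = torch.cat([torch.arange(plen) for plen in prompt_lens])
start_loc = torch.zeros(len(prompt_lens) + 1, dtype=torch.long)
start_loc[1:] = torch.cumsum(torch.tensor(prompt_lens), dim=0)

assert padded.shape == (3, 4)              # 12 slots, 3 of them padding
assert flat.shape == (9,)                  # only real tokens
assert start_loc.tolist() == [0, 4, 6, 9]  # prompt i is flat[start_loc[i]:start_loc[i+1]]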


@ -47,7 +47,7 @@ steps:
- pytest -v -s prefix_caching - pytest -v -s prefix_caching
- label: Samplers Test - label: Samplers Test
command: pytest -v -s samplers --forked command: pytest -v -s samplers
- label: Worker Test - label: Worker Test
command: pytest -v -s worker command: pytest -v -s worker
@ -56,7 +56,7 @@ steps:
command: pytest -v -s spec_decode command: pytest -v -s spec_decode
- label: LoRA Test %N - label: LoRA Test %N
command: pytest -v -s lora --forked --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4 parallelism: 4
- label: Metrics Test - label: Metrics Test


@ -13,6 +13,7 @@ MODELS = [
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False, True])
def test_models( def test_models(
hf_runner, hf_runner,
vllm_runner, vllm_runner,
@ -20,12 +21,13 @@ def test_models(
model: str, model: str,
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
enforce_eager: bool,
) -> None: ) -> None:
hf_model = hf_runner(model, dtype=dtype) hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model del hf_model
vllm_model = vllm_runner(model, dtype=dtype) vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model del vllm_model


@ -10,7 +10,7 @@ from .utils import create_dummy_prompt
def test_scheduler_add_seq_group(): def test_scheduler_add_seq_group():
block_size = 4 block_size = 4
scheduler_config = SchedulerConfig(100, 64, 1, 256) scheduler_config = SchedulerConfig(100, 64, 1)
cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 4 cache_config.num_cpu_blocks = 4
cache_config.num_gpu_blocks = 4 cache_config.num_gpu_blocks = 4
@ -26,7 +26,7 @@ def test_scheduler_add_seq_group():
def test_scheduler_abort_seq_group(): def test_scheduler_abort_seq_group():
block_size = 4 block_size = 4
scheduler_config = SchedulerConfig(100, 64, 1, 256) scheduler_config = SchedulerConfig(100, 64, 1)
cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 4 cache_config.num_cpu_blocks = 4
cache_config.num_gpu_blocks = 4 cache_config.num_gpu_blocks = 4
@ -50,7 +50,7 @@ def test_scheduler_schedule_simple():
block_size = 4 block_size = 4
num_seq_group = 4 num_seq_group = 4
max_model_len = 16 max_model_len = 16
scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len, 256) scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8 cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8 cache_config.num_gpu_blocks = 8
@ -64,10 +64,10 @@ def test_scheduler_schedule_simple():
running.append(seq_group) running.append(seq_group)
# Schedule seq groups prompts. # Schedule seq groups prompts.
num_tokens = block_size * num_seq_group
seq_group_meta, out = scheduler.schedule() seq_group_meta, out = scheduler.schedule()
assert set(out.scheduled_seq_groups) == set(running) assert set(out.scheduled_seq_groups) == set(running)
assert out.num_batched_tokens == num_seq_group * seq_group.get_seqs( assert out.num_batched_tokens == num_tokens
)[0].get_len()
assert (not out.blocks_to_copy and not out.blocks_to_swap_in assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out) and not out.blocks_to_swap_out)
assert len(seq_group_meta) == num_seq_group assert len(seq_group_meta) == num_seq_group
@ -84,7 +84,7 @@ def test_scheduler_schedule_simple():
def test_scheduler_schedule_preempt_abort(): def test_scheduler_schedule_preempt_abort():
block_size = 4 block_size = 4
max_model_len = 16 max_model_len = 16
scheduler_config = SchedulerConfig(64, 2, max_model_len, 256) scheduler_config = SchedulerConfig(64, 2, max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 2 cache_config.num_cpu_blocks = 2
cache_config.num_gpu_blocks = 2 cache_config.num_gpu_blocks = 2
@ -99,7 +99,7 @@ def test_scheduler_schedule_preempt_abort():
# Schedule seq groups prompts. # Schedule seq groups prompts.
seq_group_meta, out = scheduler.schedule() seq_group_meta, out = scheduler.schedule()
assert out.scheduled_seq_groups == [seq_group_a, seq_group_b] assert out.scheduled_seq_groups == [seq_group_a, seq_group_b]
assert out.num_batched_tokens == seq_group_a.get_seqs()[0].get_len() * 2 assert out.num_batched_tokens == block_size * 2 # seq_a and seq_b
assert (not out.blocks_to_copy and not out.blocks_to_swap_in assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out) and not out.blocks_to_swap_out)
assert len(seq_group_meta) == 2 assert len(seq_group_meta) == 2
@ -124,7 +124,7 @@ def test_scheduler_schedule_preempt_abort():
scheduler.abort_seq_group("1") scheduler.abort_seq_group("1")
seq_group_meta, out = scheduler.schedule() seq_group_meta, out = scheduler.schedule()
assert out.scheduled_seq_groups == [seq_group_b] assert out.scheduled_seq_groups == [seq_group_b]
assert out.num_batched_tokens == seq_group_b.get_seqs()[0].get_len() assert out.num_batched_tokens == 5 # 4 prompt + 1 generation.
assert (not out.blocks_to_copy and not out.blocks_to_swap_in assert (not out.blocks_to_copy and not out.blocks_to_swap_in
and not out.blocks_to_swap_out) and not out.blocks_to_swap_out)
assert len(seq_group_meta) == 1 assert len(seq_group_meta) == 1
@ -136,7 +136,7 @@ def test_scheduler_max_seqs():
num_seq_group = 4 num_seq_group = 4
max_seq_group = 2 max_seq_group = 2
max_model_len = 16 max_model_len = 16
scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len, 256) scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len)
cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8 cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8 cache_config.num_gpu_blocks = 8


@ -25,7 +25,7 @@ def test_worker_apply_lora(sql_lora_files):
revision=None, revision=None,
), ),
parallel_config=ParallelConfig(1, 1, False), parallel_config=ParallelConfig(1, 1, False),
scheduler_config=SchedulerConfig(32, 32, 32, 256), scheduler_config=SchedulerConfig(32, 32, 32),
device_config=DeviceConfig("cuda"), device_config=DeviceConfig("cuda"),
local_rank=0, local_rank=0,
rank=0, rank=0,


@ -92,8 +92,8 @@ def test_same_output_for_single_step():
num_gpu_blocks, num_gpu_blocks,
seed, seed,
) )
multi_step_worker.model_runner = worker.model_runner # multi_step_worker.model_runner = worker.model_runner
multi_step_worker.cache_engine = worker.cache_engine # multi_step_worker.cache_engine = worker.cache_engine
num_steps = 1 num_steps = 1


@ -1,14 +1,132 @@
import random import random
import torch import torch
from vllm.config import ModelConfig
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT
def get_aligned_size(batch_size: int, alignment: int):
return ((batch_size + alignment - 1) // alignment * alignment)
def test_prepare_prompt(): def test_prepare_prompt():
model_runner = ModelRunner(None, None, None, None, None) model_runner = ModelRunner(None, None, None, None, None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
batch_size = random.randint(1, 256)
prompt_lens = []
seq_group_metadata_list = []
block_tables = {0: [1]}
for i in range(batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = list(range(prompt_len))
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData(seq_data)},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
))
expected_selected_token_indices = []
selected_token_start_idx = 0
for prompt_len in prompt_lens:
expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += prompt_len
(input_tokens, input_positions, input_metadata, return_prompt_lens, _, _,
_, _) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
# Verify input metadata is correct for prompts.
device = model_runner.device
assert input_metadata.is_prompt is True
assert torch.allclose(input_metadata.prompt_lens_tensor,
torch.tensor(prompt_lens, device=device))
assert input_metadata.prompt_lens == prompt_lens
assert input_metadata.num_prompt_tokens == sum(prompt_lens)
assert input_metadata.num_generation_tokens == 0
assert input_metadata.max_seq_len == max(prompt_lens)
# Test subquery start locs.
start_idx = 0
start_loc = [start_idx]
for prompt_len in prompt_lens:
start_idx += prompt_len
start_loc.append(start_idx)
assert torch.allclose(
input_metadata.subquery_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
# Test seq start locs. Note that for normal prefill it is
# equivalent to subquery_start_loc.
start_idx = 0
seq_start_loc = [start_idx]
for prompt_len in prompt_lens:
start_idx += prompt_len
seq_start_loc.append(start_idx)
assert torch.allclose(
input_metadata.seq_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device))
assert input_metadata.max_context_len is None
assert torch.allclose(
input_metadata.context_lens,
torch.zeros(input_metadata.context_lens.shape[0],
dtype=torch.int,
device=device))
expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32,
device=model_runner.device)
assert torch.allclose(input_metadata.block_tables, expected)
# Cuda graph should not be used for prefill.
assert input_metadata.use_cuda_graph is False
assert input_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (sum(prompt_lens), )
assert input_positions.shape == (sum(prompt_lens), )
torch.testing.assert_close(input_tokens, input_positions)
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens,
subquery_lens=prompt_lens)
assert input_tokens.shape == (sum(prompt_lens), )
assert input_positions.shape == (sum(prompt_lens), )
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
torch.testing.assert_close(input_tokens, input_positions)
actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices,
device=actual.device,
dtype=actual.dtype)
torch.testing.assert_close(actual, expected)
def test_prepare_decode_cuda_graph():
model_config = ModelConfig(
"facebook/opt-125m",
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
enforce_eager=False,
)
model_runner = ModelRunner(model_config, None, None, None, None)
model_runner.set_block_size(16)
batch_size = random.randint(1, 256) batch_size = random.randint(1, 256)
prompt_lens = [] prompt_lens = []
seq_group_metadata_list = [] seq_group_metadata_list = []
@ -20,29 +138,56 @@ def test_prepare_prompt():
seq_group_metadata_list.append( seq_group_metadata_list.append(
SequenceGroupMetadata( SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=True, is_prompt=False,
seq_data={0: SequenceData(seq_data)}, seq_data={0: SequenceData(seq_data)},
sampling_params=SamplingParams(temperature=0), sampling_params=SamplingParams(temperature=0),
block_tables={0: [1]}, block_tables={0: [1]},
)) ))
input_tokens, input_positions, input_metadata, _, _, _ = (
model_runner._prepare_decode(seq_group_metadata_list))
# Verify input metadata is correct for decoding.
device = model_runner.device
assert input_metadata.is_prompt is False
assert input_metadata.prompt_lens is None
assert input_metadata.num_prompt_tokens == 0
assert input_metadata.num_generation_tokens == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT))
assert input_metadata.max_seq_len is None
assert input_metadata.subquery_start_loc is None
assert input_metadata.seq_start_loc is None
assert input_metadata.max_context_len == max(prompt_lens)
assert torch.allclose(
input_metadata.context_lens[:len(prompt_lens)],
torch.tensor(prompt_lens, dtype=torch.int, device=device))
# The block table's first dim corresponds to each entry in the batch;
# for decoding, that is one token per sequence.
assert input_metadata.block_tables.shape[0] == len(input_tokens)
# The block table's second dim corresponds to each token's block number.
# It is padded up to the maximum number of blocks per batch.
assert input_metadata.block_tables.shape[1] == (
model_runner.get_max_block_per_batch())
# Cuda graph is used for decoding requests.
assert input_metadata.use_cuda_graph is True
assert input_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
assert input_positions.shape == (get_aligned_size(
len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
torch.testing.assert_close(input_tokens, input_positions)
# Verify Sampling
expected_selected_token_indices = [] expected_selected_token_indices = []
selected_token_start_idx = 0 selected_token_start_idx = 0
max_seq_len = max(prompt_lens)
for prompt_len in prompt_lens: for prompt_len in prompt_lens:
expected_selected_token_indices.append(selected_token_start_idx + expected_selected_token_indices.append(selected_token_start_idx)
prompt_len - 1) selected_token_start_idx += 1
selected_token_start_idx += max_seq_len
input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = (
model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens, prompt_lens,
subquery_lens=prompt_lens) subquery_lens=prompt_lens)
assert input_tokens.shape == (batch_size, max_seq_len)
assert input_positions.shape == (batch_size, max_seq_len)
torch.testing.assert_close(input_tokens, input_positions)
actual = sampling_metadata.selected_token_indices actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices, expected = torch.tensor(expected_selected_token_indices,
device=actual.device, device=actual.device,


@ -535,7 +535,6 @@ class SchedulerConfig:
iteration. iteration.
max_model_len: Maximum length of a sequence (including prompt max_model_len: Maximum length of a sequence (including prompt
and generated text). and generated text).
max_paddings: Maximum number of paddings to be added to a batch.
""" """
def __init__( def __init__(
@ -543,7 +542,6 @@ class SchedulerConfig:
max_num_batched_tokens: Optional[int], max_num_batched_tokens: Optional[int],
max_num_seqs: int, max_num_seqs: int,
max_model_len: int, max_model_len: int,
max_paddings: int,
) -> None: ) -> None:
if max_num_batched_tokens is not None: if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
@ -553,7 +551,6 @@ class SchedulerConfig:
self.max_num_batched_tokens = max(max_model_len, 2048) self.max_num_batched_tokens = max(max_model_len, 2048)
self.max_num_seqs = max_num_seqs self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.max_paddings = max_paddings
self._verify_args() self._verify_args()
def _verify_args(self) -> None: def _verify_args(self) -> None:


@ -173,12 +173,12 @@ class Scheduler:
curr_loras = set( curr_loras = set(
seq_group.lora_int_id seq_group.lora_int_id
for seq_group in self.running) if self.lora_enabled else None for seq_group in self.running) if self.lora_enabled else None
seq_lens: List[int] = []
# Optimization: We do not sort the waiting queue since the preempted # Optimization: We do not sort the waiting queue since the preempted
# sequence groups are added to the front and the new sequence groups # sequence groups are added to the front and the new sequence groups
# are added to the back. # are added to the back.
leftover_waiting_sequences = deque() leftover_waiting_sequences = deque()
num_batched_tokens = 0
while self.waiting: while self.waiting:
seq_group = self.waiting[0] seq_group = self.waiting[0]
waiting_seqs = seq_group.get_seqs( waiting_seqs = seq_group.get_seqs(
@ -223,8 +223,7 @@ class Scheduler:
continue continue
# If the number of batched tokens exceeds the limit, stop. # If the number of batched tokens exceeds the limit, stop.
new_seq_lens = seq_lens + [num_prompt_tokens] num_batched_tokens += num_prompt_tokens
num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
if (num_batched_tokens > if (num_batched_tokens >
self.scheduler_config.max_num_batched_tokens): self.scheduler_config.max_num_batched_tokens):
break break
@ -236,11 +235,6 @@ class Scheduler:
self.scheduler_config.max_num_seqs): self.scheduler_config.max_num_seqs):
break break
num_paddings = num_batched_tokens - sum(new_seq_lens)
if num_paddings > self.scheduler_config.max_paddings:
break
seq_lens = new_seq_lens
if lora_int_id > 0: if lora_int_id > 0:
curr_loras.add(lora_int_id) curr_loras.add(lora_int_id)
self.waiting.popleft() self.waiting.popleft()
@ -255,8 +249,7 @@ class Scheduler:
scheduler_outputs = SchedulerOutputs( scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled, scheduled_seq_groups=scheduled,
prompt_run=True, prompt_run=True,
num_batched_tokens=len(seq_lens) * num_batched_tokens=num_batched_tokens,
max(seq_lens) if seq_lens else 0,
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
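With padding gone from the prompt batch, the scheduler's token accounting changes accordingly: instead of charging len(seq_lens) * max(seq_lens) padded tokens and separately capping the padding with max_paddings, it simply accumulates the real prompt tokens. A small worked example of the difference (illustrative lengths, not code from this diff):

# Illustrative prompt lengths for three waiting sequence groups.
seq_lens = [5, 32, 7]

# Old accounting: every prompt is padded to the longest one.
old_num_batched_tokens = len(seq_lens) * max(seq_lens)      # 3 * 32 = 96
old_num_paddings = old_num_batched_tokens - sum(seq_lens)   # 96 - 44 = 52

# New accounting: only real tokens count toward max_num_batched_tokens.
new_num_batched_tokens = sum(seq_lens)                      # 44

assert (old_num_batched_tokens, old_num_paddings, new_num_batched_tokens) == (96, 52, 44)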


@ -31,7 +31,6 @@ class EngineArgs:
gpu_memory_utilization: float = 0.90 gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256 max_num_seqs: int = 256
max_paddings: int = 256
max_logprobs: int = 5 # OpenAI default value max_logprobs: int = 5 # OpenAI default value
disable_log_stats: bool = False disable_log_stats: bool = False
revision: Optional[str] = None revision: Optional[str] = None
@ -213,10 +212,6 @@ class EngineArgs:
type=int, type=int,
default=EngineArgs.max_num_seqs, default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration') help='maximum number of sequences per iteration')
parser.add_argument('--max-paddings',
type=int,
default=EngineArgs.max_paddings,
help='maximum number of paddings in a batch')
parser.add_argument( parser.add_argument(
'--max-logprobs', '--max-logprobs',
type=int, type=int,
@ -347,8 +342,7 @@ class EngineArgs:
), self.ray_workers_use_nsight) ), self.ray_workers_use_nsight)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens, scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs, self.max_num_seqs,
model_config.max_model_len, model_config.max_model_len)
self.max_paddings)
lora_config = LoRAConfig( lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank, max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras, max_loras=self.max_loras,


@ -561,7 +561,6 @@ class LLMEngine:
# Log stats. # Log stats.
if self.log_stats: if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs)) self.stat_logger.log(self._get_stats(scheduler_outputs))
return request_outputs return request_outputs
def step(self) -> List[RequestOutput]: def step(self) -> List[RequestOutput]:


@ -1,36 +1,92 @@
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from typing import Optional, Any, Dict from typing import Optional, List, Any, Dict
import torch import torch
from xformers.ops.fmha.attn_bias import AttentionBias
@dataclass @dataclass
class InputMetadata: class InputMetadata:
"""Metadata for input sequences. Used in PagedAttention. """Metadata for input sequences. Used in PagedAttention.
Args: NOTE: Any python object stored here is not updated when it is
prompt_lens: Lengths of prompts. cuda-graph replayed. If you have values that need to be changed
slot_mapping: The address to write the new KV to of each token. dynamically, it should be stored in tensor. The tensor has to be
max_context_len: The maximum context length. updated from `CUDAGraphRunner.forward` API.
context_lens: the length of attention context for each sequence. """
block_tables: The block tables. (Seq id -> list of physical block) # Currently, input sequences can only contain all prompts
kv_cache_dtype: Data type to store kv cache. # or all decoding. True if all sequences are prompts.
is_prompt: bool
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
"""
Definition of context_len, subquery_len, and seqlen.
|---------- N-1 iteration --------|
|---------------- N iteration ---------------------|
|- tokenA -|......................|-- newTokens ---|
|---------- context_len ----------|
|-------------------- seqlen ----------------------|
|- subquery_len -|
WARNING: context_len has a different definition for prefill and decoding.
For prefill, it does not include the new tokens. For decoding, it
includes the new token.
""" """
is_prompt: bool # Maximum subquery length in the batch.
slot_mapping: torch.Tensor max_subquery_len: Optional[int]
prompt_lens: Optional[torch.Tensor] # Maximum context length in the batch.
max_seq_len: Optional[int]
start_loc: Optional[torch.Tensor]
max_context_len: Optional[int] max_context_len: Optional[int]
# FIXME: It is for flash attn.
# Maximum sequence length in the batch.
max_seq_len: Optional[int]
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
subquery_start_loc: Optional[torch.Tensor]
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor]
# (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token.
context_lens: Optional[torch.Tensor] context_lens: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# The 2nd dimension is padded up to max_blocks_per_seq if it is
# cuda-graph captured.
block_tables: Optional[torch.Tensor] block_tables: Optional[torch.Tensor]
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
use_cuda_graph: bool use_cuda_graph: bool
kv_cache_dtype: str kv_cache_dtype: str
def __post_init__(self): def __post_init__(self):
# Set during the execution of the first attention op.
# It is a list because it needs to be set per prompt when alibi
# slopes are used, due to a limitation of the xformers API.
# will not appear in the __repr__ and __init__ # will not appear in the __repr__ and __init__
self.attn_bias = None self.attn_bias: Optional[List[AttentionBias]] = None
# Cuda graph is only used for decoding now.
if self.use_cuda_graph:
assert self.num_prompt_tokens == 0
def asdict_zerocopy(self) -> Dict[str, Any]: def asdict_zerocopy(self) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying.""" """Similar to dataclasses.asdict, but avoids deepcopying."""


@ -20,8 +20,8 @@ class SiluAndMul(nn.Module):
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2. The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes: Shapes:
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d) return: (num_tokens, d) or (batch_size, seq_len, d)
""" """
def _forward(self, x: torch.Tensor) -> torch.Tensor: def _forward(self, x: torch.Tensor) -> torch.Tensor:


@ -17,11 +17,12 @@ class Attention(nn.Module):
This class takes query, key, and value tensors as input. The input tensors This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens. can either contain prompt tokens or generation tokens.
The class does the following: The class does the following:
1. Store the input key and value tensors in the KV cache. 1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention. 2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor. 3. Output the output tensor.
""" """
def __init__( def __init__(


@ -1,7 +1,7 @@
"""Attention layer with Flash and PagedAttention.""" """Attention layer with Flash and PagedAttention."""
from typing import List, Optional from typing import List, Optional
from flash_attn import flash_attn_func from flash_attn import flash_attn_varlen_func
import torch import torch
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
@ -10,6 +10,21 @@ from vllm.model_executor.layers.attention.ops.paged_attn import (
class FlashAttentionBackend: class FlashAttentionBackend:
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""
def __init__( def __init__(
self, self,
@ -52,18 +67,18 @@ class FlashAttentionBackend:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
Args: Args:
query: shape = [batch_size, seq_len, num_heads * head_size] query: shape = [num_tokens, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x, key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x] block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size, value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size] block_size]
input_metadata: metadata for the inputs. input_metadata: metadata for the inputs.
Returns: Returns:
shape = [batch_size, seq_len, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
batch_size, seq_len, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
@ -82,13 +97,16 @@ class FlashAttentionBackend:
if (key_cache is None or value_cache is None if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0): or input_metadata.block_tables.numel() == 0):
# normal attention # normal attention
query = query.unflatten(0, (batch_size, seq_len)) # When block_tables are not filled, it means q and k are the
key = key.unflatten(0, (batch_size, seq_len)) # prompt, and they have the same length.
value = value.unflatten(0, (batch_size, seq_len)) output = flash_attn_varlen_func(
output = flash_attn_func( q=query,
query, k=key,
key, v=value,
value, cu_seqlens_q=input_metadata.seq_start_loc,
cu_seqlens_k=input_metadata.seq_start_loc,
max_seqlen_q=input_metadata.max_seq_len,
max_seqlen_k=input_metadata.max_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
window_size=self.sliding_window, window_size=self.sliding_window,
@ -118,4 +136,4 @@ class FlashAttentionBackend:
) )
# Reshape the output tensor. # Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size) return output.view(num_tokens, hidden_size)
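For reference, the math the prompt path now performs on the flattened layout, written as a plain-PyTorch sketch rather than the actual flash_attn_varlen_func kernel: causal attention is applied independently to each prompt's slice of the token dimension.

import math
import torch

def varlen_causal_attention_reference(query, key, value, seq_start_loc, scale):
    """Per-prompt causal attention over flattened [num_tokens, heads, head_size] inputs."""
    output = torch.empty_like(query)
    for start, end in zip(seq_start_loc[:-1], seq_start_loc[1:]):
        q = query[start:end]                      # [prompt_len, heads, head_size]
        k = key[start:end]
        v = value[start:end]
        scores = torch.einsum("qhd,khd->hqk", q, k) * scale
        # Mask out future positions within this prompt only.
        mask = torch.triu(torch.ones(end - start, end - start, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
        probs = scores.softmax(dim=-1)
        output[start:end] = torch.einsum("hqk,khd->qhd", probs, v)
    return output

# Illustrative shapes: two prompts of lengths 4 and 6, 2 heads, head_size 8.
seq_start_loc = [0, 4, 10]
q = torch.randn(10, 2, 8)
k = torch.randn(10, 2, 8)
v = torch.randn(10, 2, 8)
out = varlen_causal_attention_reference(q, k, v, seq_start_loc, scale=1 / math.sqrt(8))
assert out.shape == (10, 2, 8)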


@ -14,6 +14,21 @@ from vllm.utils import is_hip
class XFormersBackend: class XFormersBackend:
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens --------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
"""
def __init__( def __init__(
self, self,
@ -55,19 +70,18 @@ class XFormersBackend:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
Args: Args:
query: shape = [batch_size, seq_len, num_heads * head_size] query: shape = [num_tokens, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x, key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x] block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size, value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size] block_size]
input_metadata: metadata for the inputs. input_metadata: metadata for the inputs.
Returns: Returns:
shape = [batch_size, seq_len, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
batch_size, seq_len, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
@ -82,9 +96,10 @@ class XFormersBackend:
if input_metadata.is_prompt: if input_metadata.is_prompt:
# Prompt run. # Prompt run.
# key_cache and value_cache are None when it is a profiling run.
# block tables are empty if the prompt has never been computed.
if (key_cache is None or value_cache is None if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0): or input_metadata.block_tables.numel() == 0):
# normal attention
if self.num_kv_heads != self.num_heads: if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA, # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of # project the key and value tensors to the desired number of
@ -103,61 +118,33 @@ class XFormersBackend:
self.num_queries_per_kv, self.num_queries_per_kv,
value.shape[-1]) value.shape[-1])
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)
if self.use_ref_attention: if self.use_ref_attention:
output = _ref_masked_attention( print("ref attention used.")
query, output = torch.empty_like(query)
key, start = 0
value, for _, prompt_len in enumerate(input_metadata.prompt_lens):
end = start + prompt_len
out = _ref_masked_attention(
query[None, start:end],
key[None, start:end],
value[None, start:end],
self.num_heads, self.num_heads,
self.num_kv_heads, self.num_kv_heads,
self.head_size, self.head_size,
self.scale, self.scale,
) )
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out)
start += prompt_len
# Using view got RuntimeError: view size is not compatible # Using view got RuntimeError: view size is not compatible
# with input tensor's size and stride (at least one # with input tensor's size and stride (at least one
# dimension spans across two contiguous subspaces). # dimension spans across two contiguous subspaces).
# Use reshape instead. # Use reshape instead.
return output.reshape(batch_size, seq_len, hidden_size) return output.reshape(num_tokens, hidden_size)
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
else:
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = out.view_as(query)
output = self._run_memory_efficient_xformer_forward(
query, key, value, input_metadata)
else: else:
# prefix-enabled attention # prefix-enabled attention
output = PagedAttentionImpl.forward_prefix( output = PagedAttentionImpl.forward_prefix(
@ -182,41 +169,117 @@ class XFormersBackend:
) )
# Reshape the output tensor. # Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size) return output.view(-1, self.num_heads * self.head_size)
def _run_memory_efficient_xformer_forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
input_metadata: metadata for paged attention.
"""
# Set attention bias if not provided. This typically happens at
# the first attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
input_metadata.prompt_lens)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = [attn_bias]
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, query.dtype,
input_metadata)
op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if (
is_hip()) else None
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias[0],
p=0.0,
scale=self.scale,
op=op)
return out.view_as(query)
# Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
output = torch.empty_like(query)
start = 0
for i, prompt_len in enumerate(input_metadata.prompt_lens):
end = start + prompt_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=input_metadata.attn_bias[i],
p=0.0,
scale=self.scale,
op=op)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.squeeze(0))
start += prompt_len
return output
def _make_alibi_bias( def _make_alibi_bias(
alibi_slopes: torch.Tensor, alibi_slopes: torch.Tensor,
num_kv_heads: int, num_kv_heads: int,
batch_size: int,
seq_len: int,
dtype: torch.dtype, dtype: torch.dtype,
input_metadata: InputMetadata,
) -> LowerTriangularMaskWithTensorBias: ) -> LowerTriangularMaskWithTensorBias:
bias = torch.arange(seq_len, dtype=dtype) attn_biases = []
for prompt_len in input_metadata.prompt_lens:
bias = torch.arange(prompt_len, dtype=dtype)
# NOTE(zhuohan): HF uses # NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)` # `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but # here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi # the bias below more accurately follows the original ALiBi
# paper. # paper.
# Build a relative-position matrix where entry (i, j) equals j - i.
bias = bias[None, :] - bias[:, None] bias = bias[None, :] - bias[:, None]
# When using custom attention bias, xformers requires the bias to padded_len = (prompt_len + 7) // 8 * 8
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0] num_heads = alibi_slopes.shape[0]
bias = torch.empty( bias = torch.empty(
batch_size, 1, # batch size
num_heads, num_heads,
seq_len, prompt_len,
padded_len, padded_len,
device=alibi_slopes.device, device=alibi_slopes.device,
dtype=dtype, dtype=dtype,
)[:, :, :, :seq_len].copy_(bias) )[:, :, :, :prompt_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None]) bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads: if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
attn_bias = LowerTriangularMaskWithTensorBias(bias) attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
return attn_bias
return attn_biases
def _check_use_ref_attention() -> bool: def _check_use_ref_attention() -> bool:
@ -239,7 +302,6 @@ def _ref_masked_attention(
query = query.view(-1, num_heads, head_size) query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size) key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size) value = value.view(-1, num_kv_heads, head_size)
seq_len, _, _ = query.shape seq_len, _, _ = query.shape
attn_mask = torch.triu(torch.ones(seq_len, attn_mask = torch.triu(torch.ones(seq_len,
seq_len, seq_len,
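The ALiBi path now builds one bias per prompt, since prompts in the flattened batch can have different lengths. A torch-only sketch of how a single prompt's bias is assembled (illustrative slope values, not code from this diff); the padded allocation mirrors the multiple-of-8 slicing requirement noted in the diff:

import torch

# Illustrative: 4 heads, one prompt of length 5; slopes are made-up values.
alibi_slopes = torch.tensor([0.5, 0.25, 0.125, 0.0625])
prompt_len = 5

# Relative positions: entry (i, j) equals j - i.
pos = torch.arange(prompt_len, dtype=torch.float32)
rel = pos[None, :] - pos[:, None]

# xformers requires the bias to be sliced from a tensor whose last
# dimension is a multiple of 8, hence the padded allocation.
padded_len = (prompt_len + 7) // 8 * 8          # 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(1, num_heads, prompt_len, padded_len)[:, :, :, :prompt_len].copy_(rel)
bias.mul_(alibi_slopes[:, None, None])          # scale each head by its slope

assert bias.shape == (1, num_heads, prompt_len, prompt_len)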


@ -128,11 +128,12 @@ class PagedAttentionImpl:
output, output,
key_cache, key_cache,
value_cache, value_cache,
input_metadata.block_tables, # [BS, max_block_per_request] input_metadata.block_tables,
input_metadata.start_loc, # subquery_start_loc is (batch_size + 1,)
input_metadata.prompt_lens, input_metadata.subquery_start_loc[:-1],
input_metadata.prompt_lens_tensor,
input_metadata.context_lens, input_metadata.context_lens,
input_metadata.max_seq_len, input_metadata.max_subquery_len,
alibi_slopes, alibi_slopes,
) )
return output return output


@ -128,7 +128,6 @@ def _prune_hidden_states(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
return hidden_states.index_select(0, return hidden_states.index_select(0,
sampling_metadata.selected_token_indices) sampling_metadata.selected_token_indices)


@ -28,9 +28,12 @@ logger = init_logger(__name__)
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
_PAD_SLOT_ID = -1 _PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8 LORA_WARMUP_RANK = 8
# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. _BATCH_SIZE_ALIGNMENT = 8
# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
# NOTE: _get_graph_batch_size needs to be updated if this list is changed. # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]
class ModelRunner: class ModelRunner:
@ -107,8 +110,7 @@ class ModelRunner:
), "Model does not have embedding_padding_modules" ), "Model does not have embedding_padding_modules"
self.lora_manager = LRUCacheWorkerLoRAManager( self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens + self.scheduler_config.max_num_batched_tokens, self.vocab_size,
self.scheduler_config.max_paddings, self.vocab_size,
self.lora_config, self.device, self.model.embedding_modules, self.lora_config, self.device, self.model.embedding_modules,
self.model.embedding_padding_modules) self.model.embedding_padding_modules)
self.model = self.lora_manager.create_lora_manager(self.model) self.model = self.lora_manager.create_lora_manager(self.model)
@ -116,10 +118,13 @@ class ModelRunner:
def set_block_size(self, block_size: int) -> None: def set_block_size(self, block_size: int) -> None:
self.block_size = block_size self.block_size = block_size
max_num_blocks = (self.max_context_len_to_capture + block_size -
1) // block_size
self.graph_block_tables = np.zeros( self.graph_block_tables = np.zeros(
(max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
dtype=np.int32)
def get_max_block_per_batch(self) -> int:
block_size = self.block_size
return (self.max_context_len_to_capture + block_size - 1) // block_size
def _prepare_prompt( def _prepare_prompt(
self, self,
@ -127,9 +132,9 @@ class ModelRunner:
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
List[int], List[int], Set[LoRARequest]]: List[int], List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = [] input_tokens: List[int] = []
input_positions: List[List[int]] = [] input_positions: List[int] = []
slot_mapping: List[List[int]] = [] slot_mapping: List[int] = []
lora_index_mapping: List[int] = [] lora_index_mapping: List[int] = []
lora_prompt_mapping: List[int] = [] lora_prompt_mapping: List[int] = []
lora_requests: Set[LoRARequest] = set() lora_requests: Set[LoRARequest] = set()
@ -158,16 +163,18 @@ class ModelRunner:
computed_len = len(computed_block_nums) * self.block_size computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:] prompt_tokens = prompt_tokens[computed_len:]
prefix_block_tables.append(computed_block_nums) prefix_block_tables.append(computed_block_nums)
context_len = computed_len
else: else:
prefix_block_tables.append([]) prefix_block_tables.append([])
context_len = 0
# actual prompt lens # actual prompt lens
context_lens.append(computed_len) context_lens.append(context_len)
subquery_lens.append(prompt_len - computed_len) subquery_lens.append(prompt_len - computed_len)
input_tokens.append(prompt_tokens) input_tokens.extend(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt # NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence. # is always the first token in the sequence.
input_positions.append( input_positions.extend(
list(range(computed_len, computed_len + len(prompt_tokens)))) list(range(computed_len, computed_len + len(prompt_tokens))))
lora_id = seq_group_metadata.lora_int_id lora_id = seq_group_metadata.lora_int_id
@ -175,7 +182,7 @@ class ModelRunner:
if lora_id > 0: if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request) lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping.append([lora_id] * (prompt_len - computed_len)) lora_index_mapping += [lora_id] * (prompt_len - computed_len)
lora_prompt_mapping.extend( lora_prompt_mapping.extend(
[lora_id] * [lora_id] *
(prompt_len - computed_len (prompt_len - computed_len
@ -184,11 +191,10 @@ class ModelRunner:
if seq_group_metadata.block_tables is None: if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized # During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping. # yet. In this case, we just use a dummy slot mapping.
slot_mapping.append([_PAD_SLOT_ID] * prompt_len) slot_mapping.extend([_PAD_SLOT_ID] * prompt_len)
continue continue
# Compute the slot mapping. # Compute the slot mapping.
slot_mapping.append([])
block_table = seq_group_metadata.block_tables[seq_id] block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window). # where start_idx is max(0, prompt_len - sliding_window).
@ -203,35 +209,30 @@ class ModelRunner:
start_idx = max(0, prompt_len - self.sliding_window) start_idx = max(0, prompt_len - self.sliding_window)
for i in range(computed_len, prompt_len): for i in range(computed_len, prompt_len):
if i < start_idx: if i < start_idx:
slot_mapping[-1].append(_PAD_SLOT_ID) slot_mapping.append(_PAD_SLOT_ID)
continue continue
block_number = block_table[i // self.block_size] block_number = block_table[i // self.block_size]
block_offset = i % self.block_size block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot) slot_mapping.append(slot)
max_prompt_len = max(subquery_lens) max_subquery_len = max(subquery_lens)
assert max_prompt_len > 0 max_seq_len = max(prompt_lens)
input_tokens = _make_tensor_with_pad(input_tokens, num_prompt_tokens = len(input_tokens)
max_prompt_len, assert max_subquery_len > 0
pad=0,
input_tokens = torch.tensor(input_tokens,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
input_positions = _make_tensor_with_pad(input_positions, input_positions = torch.tensor(input_positions,
max_prompt_len,
pad=0,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
slot_mapping = _make_tensor_with_pad(slot_mapping, slot_mapping = torch.tensor(slot_mapping,
max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
lora_index_mapping = [ lora_index_mapping = lora_index_mapping
_pad_to_max(mapping, max_prompt_len, pad=0)
for mapping in lora_index_mapping
]
context_lens_tensor = torch.tensor(context_lens, context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int, dtype=torch.int,
device=self.device) device=self.device)
@ -244,22 +245,45 @@ class ModelRunner:
dtype=torch.int, dtype=torch.int,
device=self.device, device=self.device,
) )
start_loc_tensor = torch.arange(0,
len(prompt_lens) * max_prompt_len, # Query length can be shorter than key (i.e., prompt) when prefill
max_prompt_len, # is chunked or prefix cached.
subquery_lens_tensor = torch.tensor(subquery_lens,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
prompt_lens_tensor = torch.tensor(prompt_lens, prompt_lens_tensor = torch.tensor(prompt_lens,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1,
dtype=torch.int32,
device=self.device)
torch.cumsum(subquery_lens_tensor,
dim=0,
dtype=subquery_start_loc.dtype,
out=subquery_start_loc[1:])
torch.cumsum(prompt_lens_tensor,
dim=0,
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
input_metadata = InputMetadata( input_metadata = InputMetadata(
is_prompt=True, is_prompt=True,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
prompt_lens=prompt_lens_tensor, prompt_lens=prompt_lens,
max_seq_len=max_prompt_len, prompt_lens_tensor=prompt_lens_tensor,
start_loc=start_loc_tensor, num_prompt_tokens=num_prompt_tokens,
num_generation_tokens=0,
max_subquery_len=max_subquery_len,
max_context_len=None, max_context_len=None,
max_seq_len=max_seq_len,
subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc,
context_lens=context_lens_tensor, context_lens=context_lens_tensor,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=False, use_cuda_graph=False,
@ -275,9 +299,9 @@ class ModelRunner:
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
Set[LoRARequest]]: Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = [] input_tokens: List[int] = []
input_positions: List[List[int]] = [] input_positions: List[int] = []
slot_mapping: List[List[int]] = [] slot_mapping: List[int] = []
context_lens: List[int] = [] context_lens: List[int] = []
block_tables: List[List[int]] = [] block_tables: List[List[int]] = []
lora_index_mapping: List[int] = [] lora_index_mapping: List[int] = []
@ -296,11 +320,11 @@ class ModelRunner:
for seq_id in seq_ids: for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id] seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id() generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token]) input_tokens.append(generation_token)
seq_len = seq_data.get_len() seq_len = seq_data.get_len()
position = seq_len - 1 position = seq_len - 1
input_positions.append([position]) input_positions.append(position)
context_len = seq_len if self.sliding_window is None else min( context_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window) seq_len, self.sliding_window)
@ -310,8 +334,8 @@ class ModelRunner:
block_number = block_table[position // self.block_size] block_number = block_table[position // self.block_size]
block_offset = position % self.block_size block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset slot = block_number * self.block_size + block_offset
slot_mapping.append([slot]) slot_mapping.append(slot)
lora_index_mapping.append([lora_id]) lora_index_mapping.append(lora_id)
lora_prompt_mapping.append(lora_id) lora_prompt_mapping.append(lora_id)
if self.sliding_window is not None: if self.sliding_window is not None:
@ -320,6 +344,9 @@ class ModelRunner:
block_table = block_table[-sliding_window_blocks:] block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table) block_tables.append(block_table)
# vLLM uses cuda graph only for decoding requests.
# See `capture_model` API for more details.
# For decoding requests, batch_size == number of input tokens.
batch_size = len(input_tokens) batch_size = len(input_tokens)
max_context_len = max(context_lens) max_context_len = max(context_lens)
use_captured_graph = ( use_captured_graph = (
@ -327,31 +354,24 @@ class ModelRunner:
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_context_len <= self.max_context_len_to_capture) and max_context_len <= self.max_context_len_to_capture)
if use_captured_graph: if use_captured_graph:
# Pad the input tokens, positions, and slot mapping to match the
# batch size of the captured graph.
graph_batch_size = _get_graph_batch_size(batch_size) graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size assert graph_batch_size >= batch_size
for _ in range(graph_batch_size - batch_size): for _ in range(graph_batch_size - batch_size):
input_tokens.append([]) input_tokens.append(0)
input_positions.append([]) input_positions.append(0)
slot_mapping.append([]) slot_mapping.append(_PAD_SLOT_ID)
context_lens.append(1) context_lens.append(1)
block_tables.append([]) block_tables.append([])
lora_index_mapping.append(0)
batch_size = graph_batch_size batch_size = graph_batch_size
input_tokens = _make_tensor_with_pad(input_tokens, input_tokens = torch.tensor(input_tokens,
max_len=1,
pad=0,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
input_positions = _make_tensor_with_pad(input_positions, input_positions = torch.tensor(input_positions,
max_len=1,
pad=0,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
slot_mapping = _make_tensor_with_pad(slot_mapping, slot_mapping = torch.tensor(slot_mapping,
max_len=1,
pad=_PAD_SLOT_ID,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
context_lens = torch.tensor(context_lens, context_lens = torch.tensor(context_lens,
@ -359,6 +379,12 @@ class ModelRunner:
device=self.device) device=self.device)
if use_captured_graph: if use_captured_graph:
# When using cuda-graph all these tensors should be
# padded.
assert context_lens.shape[0] == input_tokens.shape[0]
assert context_lens.shape[0] == input_positions.shape[0]
assert context_lens.shape[0] == slot_mapping.shape[0]
# The shape of graph_block_tables is # The shape of graph_block_tables is
# [max batch size, max context len // block size]. # [max batch size, max context len // block size].
input_block_tables = self.graph_block_tables[:batch_size] input_block_tables = self.graph_block_tables[:batch_size]
@ -377,17 +403,18 @@ class ModelRunner:
device=self.device, device=self.device,
) )
lora_index_mapping = [
_pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping
]
input_metadata = InputMetadata( input_metadata = InputMetadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
prompt_lens=None, prompt_lens=None,
max_seq_len=None, prompt_lens_tensor=None,
start_loc=None, num_prompt_tokens=0,
num_generation_tokens=len(input_tokens),
max_subquery_len=None,
max_context_len=max_context_len, max_context_len=max_context_len,
max_seq_len=None,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens, context_lens=context_lens,
block_tables=block_tables, block_tables=block_tables,
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
@ -411,7 +438,6 @@ class ModelRunner:
categorized_sampled_token_indices_start_idx = 0 categorized_sampled_token_indices_start_idx = 0
pin_memory = not self.in_wsl and not self.device_config.is_neuron pin_memory = not self.in_wsl and not self.device_config.is_neuron
max_subquery_len = max(subquery_lens) if subquery_lens else 1
for i, seq_group_metadata in enumerate(seq_group_metadata_list): for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys()) seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params sampling_params = seq_group_metadata.sampling_params
@ -439,7 +465,7 @@ class ModelRunner:
selected_token_start_idx + subquery_len - 1)) selected_token_start_idx + subquery_len - 1))
selected_token_indices.append(selected_token_start_idx + selected_token_indices.append(selected_token_start_idx +
subquery_len - 1) subquery_len - 1)
selected_token_start_idx += max_subquery_len selected_token_start_idx += subquery_len
if sampling_params.seed is not None: if sampling_params.seed is not None:
seq_group_metadata.state.generator = torch.Generator( seq_group_metadata.state.generator = torch.Generator(
@ -521,11 +547,8 @@ class ModelRunner:
subquery_lens) subquery_lens)
if self.lora_config: if self.lora_config:
flat_lora_index_mapping = [
item for sublist in lora_index_mapping for item in sublist
]
lora_mapping = LoRAMapping( lora_mapping = LoRAMapping(
flat_lora_index_mapping, lora_index_mapping,
lora_prompt_mapping, lora_prompt_mapping,
) )
else: else:
@ -679,6 +702,18 @@ class ModelRunner:
@torch.inference_mode() @torch.inference_mode()
def capture_model(self, kv_caches: List[KVCache]) -> None: def capture_model(self, kv_caches: List[KVCache]) -> None:
"""Cuda graph capture a model.
Note that CUDA graph's performance gain is negligible if number
of batched tokens are larger than 200. And since CUDA graph
requires fixed sized tensors, supporting large/variable batch
size requires high GPU memory overhead. Thus, vLLM only captures
decoding requests. Mixed batch (chunked prefill + decoding) or
prefill requests are not captured.
Since it is used for decoding-only, it assumes there's only 1 token
per sequence in the batch.
"""
# NOTE(woosuk): This is a hack to ensure that the NCCL backend is never # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
# deleted before the CUDA graphs. # deleted before the CUDA graphs.
self.cupy_nccl_backend = cupy_utils.get_nccl_backend() self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
@ -697,10 +732,9 @@ class ModelRunner:
# Prepare dummy inputs. These will be reused for all batch sizes. # Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda() input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
input_positions = torch.zeros(max_batch_size, 1, input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
slot_mapping.fill_(_PAD_SLOT_ID) slot_mapping.fill_(_PAD_SLOT_ID)
context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda()
@ -726,9 +760,14 @@ class ModelRunner:
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping[:batch_size], slot_mapping=slot_mapping[:batch_size],
prompt_lens=None, prompt_lens=None,
max_seq_len=None, prompt_lens_tensor=None,
start_loc=None, num_prompt_tokens=0,
num_generation_tokens=batch_size,
max_subquery_len=None,
max_context_len=self.max_context_len_to_capture, max_context_len=self.max_context_len_to_capture,
max_seq_len=None,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens[:batch_size], context_lens=context_lens[:batch_size],
block_tables=block_tables[:batch_size], block_tables=block_tables[:batch_size],
use_cuda_graph=True, use_cuda_graph=True,
@ -845,7 +884,6 @@ class CUDAGraphRunner:
non_blocking=True) non_blocking=True)
self.input_buffers["block_tables"].copy_(input_metadata.block_tables, self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
non_blocking=True) non_blocking=True)
# Run the graph. # Run the graph.
self.graph.replay() self.graph.replay()
@ -877,17 +915,28 @@ def _make_tensor_with_pad(
dtype: torch.dtype, dtype: torch.dtype,
device: Optional[Union[str, torch.device]], device: Optional[Union[str, torch.device]],
) -> torch.Tensor: ) -> torch.Tensor:
"""Make a padded tensor of a 2D inputs.
The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
return torch.tensor(padded_x, dtype=dtype, device=device) return torch.tensor(padded_x, dtype=dtype, device=device)
def _get_graph_batch_size(batch_size: int) -> int: def _get_graph_batch_size(batch_size: int) -> int:
"""Returns the padded batch size given actual batch size.
Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
"""
if batch_size <= 2: if batch_size <= 2:
return batch_size return batch_size
elif batch_size <= 4: elif batch_size <= 4:
return 4 return 4
else: else:
return (batch_size + 7) // 8 * 8 return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
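# A quick sanity check of the padding behavior above (illustrative, assuming
# _BATCH_SIZE_ALIGNMENT == 8 as defined in this file): decode batches are
# padded up to the nearest captured size so a captured graph can be reused.
assert [_get_graph_batch_size(b) for b in (1, 2, 3, 5, 9, 17)] == [1, 2, 4, 8, 16, 24]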
def _async_h2d( def _async_h2d(