[Performance] Optimize e2e overheads: Reduce python allocations (#7162)

Alexander Matveev 2024-08-09 00:34:28 -04:00 committed by GitHub
parent 73388c07a4
commit e02ac55617
11 changed files with 549 additions and 124 deletions

View File

@@ -259,7 +259,11 @@ class FlashAttentionMetadataBuilder(
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:]
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
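When curr_sliding_window_block is 0 there is no window to trim, but the old one-liner still paid for a full copy: a -0 slice yields the entire list as a fresh allocation. A short illustration of that Python slicing behavior (values made up):

blocks = [7, 3, 9]
assert blocks[-0:] == blocks        # -0 slices the full list...
assert blocks[-0:] is not blocks    # ...as a brand-new copy
assert blocks[-2:] == [3, 9]        # the intended sliding-window tail

The new == 0 branch returns the existing block table and skips that copy.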

View File

@@ -68,13 +68,21 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
# tokens are masked and the slot mapping will be
# [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
block_table = block_tables[seq_id]
slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
for i in range(max(start_idx, context_len), seq_len):
def add_slot(i):
block_number = block_table[i // block_size]
block_offset = i % block_size
slot = block_number * block_size + block_offset
slot_mapping.append(slot)
if start_idx == 0 and (seq_len - context_len) == 1:
# Optimization for the common case of decoding the next token
add_slot(seq_len - 1)
else:
slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
for i in range(max(start_idx, context_len), seq_len):
add_slot(i)
TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
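The hunk above special-cases the dominant decode pattern (start_idx == 0, exactly one new token): no padding, no range loop, a single slot computation. A self-contained sketch of the arithmetic with illustrative values:

block_size = 16
block_table = [7, 3]                 # physical block numbers (made up)
seq_len = 18                         # context_len == 17, one new token

i = seq_len - 1                      # index of the token being decoded
block_number = block_table[i // block_size]         # 17 // 16 -> block 3
slot = block_number * block_size + i % block_size   # 3 * 16 + 1
assert slot == 49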

View File

@@ -1,5 +1,5 @@
"""Token blocks."""
from typing import List
from typing import List, Optional
from vllm.utils import Device
@@ -37,5 +37,47 @@ class PhysicalTokenBlock:
f'computed={self.computed})')
# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]
class BlockTable:
"""Holds a list of blocks with caching of their associated block_ids
"""
def __init__(self, blocks: Optional[List[PhysicalTokenBlock]] = None):
self._blocks: List[PhysicalTokenBlock] = []
self._block_ids: List[int] = []
if blocks is not None:
for block in blocks:
self.append(block)
def append(self, block: PhysicalTokenBlock):
self._blocks.append(block)
self._block_ids.append(block.block_number)
def __len__(self) -> int:
return len(self._blocks)
def __getitem__(self, key):
return self._blocks[key]
def __setitem__(self, key, value):
if isinstance(key, slice):
blocks = value
self._blocks[key] = blocks
self._block_ids[key] = [b.block_number for b in blocks]
else:
block = value
self._blocks[key] = block
self._block_ids[key] = block.block_number
def reset(self):
self._blocks = []
self._block_ids = []
def copy(self) -> "BlockTable":
return BlockTable(self._blocks)
def list(self) -> List[PhysicalTokenBlock]:
return self._blocks
def ids(self) -> List[int]:
return self._block_ids
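The wrapper mirrors every structural update into _block_ids, so ids() returns a pre-built list instead of a comprehension rebuilt on every call (compare get_block_table below). A small usage sketch; _Block is a stand-in for PhysicalTokenBlock:

class _Block:
    def __init__(self, block_number):
        self.block_number = block_number

table = BlockTable([_Block(7), _Block(3)])
assert table.ids() == [7, 3]   # already materialized; no per-call rebuild
table[1] = _Block(9)           # __setitem__ keeps the id cache in sync
assert table.ids() == [7, 9]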

View File

@@ -170,7 +170,7 @@ class UncachedBlockAllocator(BlockAllocatorBase):
self.num_blocks = num_blocks
# Initialize the free blocks.
self.free_blocks: BlockTable = []
self.free_blocks: List[PhysicalTokenBlock] = []
for i in range(num_blocks):
block = PhysicalTokenBlock(device=device,
block_number=i,
@@ -256,6 +256,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
Device.CPU, block_size, num_cpu_blocks)
# Mapping: seq_id -> BlockTable.
self.block_tables: Dict[int, BlockTable] = {}
# Mapping: req_id -> BlockTable
# Note that each SequenceGroup has a unique
# request ID
@@ -299,7 +300,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = seq.n_blocks
block_table: BlockTable = []
block_table: BlockTable = BlockTable()
for logical_idx in range(num_prompt_blocks):
if (self.block_sliding_window is not None
and logical_idx >= self.block_sliding_window):
@@ -326,15 +327,19 @@ class BlockSpaceManagerV1(BlockSpaceManager):
#
# NOTE: Here we assume that all sequences in the group have the same
# decoder prompt.
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
wait_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
seq = wait_seqs[0]
block_table: BlockTable = \
self._allocate_sequence(seq,
seq_group.num_seqs(),
is_encoder_decoder)
# Assign the self-attention block tables for each sequence.
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy()
if len(wait_seqs) == 1:
self.block_tables[wait_seqs[0].seq_id] = block_table
else:
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy()
# Allocate encoder sequence
if is_encoder_decoder:
@@ -476,6 +481,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
return
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.copy()
# When using a sliding window, blocks will be eventually reused.
# In this case the block tables will contain repeated blocks.
# When forking, we must make sure that each block's `ref_count`
@@ -527,7 +533,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
dest_allocator: BlockAllocatorBase,
mapping: Dict[PhysicalTokenBlock,
PhysicalTokenBlock]) -> BlockTable:
new_block_table = []
new_block_table: BlockTable = BlockTable()
for from_block in block_table:
if from_block in mapping:
@@ -553,8 +559,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
self.block_tables[seq.seq_id] = \
self._swap_block_table(self.block_tables[seq.seq_id],
self.cpu_allocator,
self.gpu_allocator,
self.cpu_allocator, self.gpu_allocator,
mapping)
if seq_group.is_encoder_decoder():
@@ -580,8 +585,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
self.block_tables[seq.seq_id] = \
self._swap_block_table(self.block_tables[seq.seq_id],
self.gpu_allocator,
self.cpu_allocator,
self.gpu_allocator, self.cpu_allocator,
mapping)
if seq_group.is_encoder_decoder():
@@ -636,8 +640,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
self.cross_block_tables.clear()
def get_block_table(self, seq: Sequence) -> List[int]:
block_table = self.block_tables[seq.seq_id]
return [block.block_number for block in block_table]
return self.block_tables[seq.seq_id].ids()
def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
block_table = self.cross_block_tables[seq_group.request_id]

View File

@@ -13,6 +13,7 @@ from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
from vllm.utils import PyObjectCache
logger = init_logger(__name__)
@@ -176,10 +177,10 @@ class SchedulerRunningOutputs:
enough memory, it can be preempted (for recompute) or swapped out.
"""
# Selected sequences that are running and in a decoding phase.
decode_seq_groups: List[SequenceGroup]
decode_seq_groups: List[ScheduledSequenceGroup]
# Selected sequences that are running and in a prefill phase.
# I.e., it means the prefill has been chunked.
prefill_seq_groups: List[SequenceGroup]
prefill_seq_groups: List[ScheduledSequenceGroup]
# The preempted sequences.
preempted: List[SequenceGroup]
# Sequences that are swapped out.
@@ -191,6 +192,10 @@ class SchedulerRunningOutputs:
# The number of slots for lookahead decoding.
num_lookahead_slots: int
# Optimization for fast access to the seq_group lists
decode_seq_groups_list: List[SequenceGroup]
prefill_seq_groups_list: List[SequenceGroup]
@classmethod
def create_empty(cls) -> "SchedulerRunningOutputs":
return SchedulerRunningOutputs(
@@ -201,6 +206,8 @@
blocks_to_swap_out=[],
blocks_to_copy=[],
num_lookahead_slots=0,
decode_seq_groups_list=[],
prefill_seq_groups_list=[],
)
@@ -259,6 +266,30 @@
)
def seq_group_metadata_builder():
return SequenceGroupMetadata(request_id="",
is_prompt=False,
seq_data={},
sampling_params=None,
block_tables={})
def scheduler_running_outputs_builder():
return SchedulerRunningOutputs(decode_seq_groups=[],
prefill_seq_groups=[],
preempted=[],
swapped_out=[],
blocks_to_swap_out=[],
blocks_to_copy=[],
num_lookahead_slots=0,
prefill_seq_groups_list=[],
decode_seq_groups_list=[])
def scheduled_seq_group_builder():
return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
class Scheduler:
def __init__(
@@ -331,6 +362,14 @@
else 0)
self.num_cumulative_preemption: int = 0
# Used to cache python objects
self._seq_group_metadata_cache: PyObjectCache = PyObjectCache(
seq_group_metadata_builder)
self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache(
scheduler_running_outputs_builder)
self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache(
scheduled_seq_group_builder)
@property
def lora_enabled(self) -> bool:
return bool(self.lora_config)
@@ -441,14 +480,30 @@
Returns:
SchedulerRunningOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out: List[Tuple[int, int]] = []
blocks_to_copy: List[Tuple[int, int]] = []
ret: SchedulerRunningOutputs = \
self._scheduler_running_outputs_cache.get_object()
ret.blocks_to_swap_out.clear()
ret.blocks_to_copy.clear()
ret.decode_seq_groups.clear()
ret.prefill_seq_groups.clear()
ret.preempted.clear()
ret.swapped_out.clear()
decode_seq_groups: List[ScheduledSequenceGroup] = []
prefill_seq_groups: List[ScheduledSequenceGroup] = []
preempted: List[SequenceGroup] = []
swapped_out: List[SequenceGroup] = []
ret.num_lookahead_slots = self._get_num_lookahead_slots(
is_prefill=False)
ret.decode_seq_groups_list.clear()
ret.prefill_seq_groups_list.clear()
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out
blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy
decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups
prefill_seq_groups: List[
ScheduledSequenceGroup] = ret.prefill_seq_groups
preempted: List[SequenceGroup] = ret.preempted
swapped_out: List[SequenceGroup] = ret.swapped_out
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
@@ -497,15 +552,19 @@
else:
self._append_slots(seq_group, blocks_to_copy)
is_prefill = seq_group.is_prefill()
scheduled_seq_group: ScheduledSequenceGroup = \
self._scheduled_seq_group_cache.get_object()
scheduled_seq_group.seq_group = seq_group
if is_prefill:
prefill_seq_groups.append(
ScheduledSequenceGroup(
seq_group=seq_group,
token_chunk_size=num_running_tokens))
scheduled_seq_group.token_chunk_size = num_running_tokens
prefill_seq_groups.append(scheduled_seq_group)
ret.prefill_seq_groups_list.append(seq_group)
else:
decode_seq_groups.append(
ScheduledSequenceGroup(seq_group=seq_group,
token_chunk_size=1))
scheduled_seq_group.token_chunk_size = 1
decode_seq_groups.append(scheduled_seq_group)
ret.decode_seq_groups_list.append(seq_group)
budget.add_num_batched_tokens(seq_group.request_id,
num_running_tokens)
# OPTIMIZATION: Note that get_max_num_running_seqs is
@@ -518,15 +577,10 @@
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id)
return SchedulerRunningOutputs(
decode_seq_groups=decode_seq_groups,
prefill_seq_groups=prefill_seq_groups,
preempted=preempted,
swapped_out=swapped_out,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=False))
self._scheduler_running_outputs_cache.reset()
self._scheduled_seq_group_cache.reset()
return ret
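The shape of this method is now: take a recycled SchedulerRunningOutputs from the cache, clear its mutable fields, fill it in place, and reset() the caches once the pass is done so the same objects are handed out next iteration. A runnable sketch of that lifecycle built on PyObjectCache (added to vllm/utils.py in this commit); the dict builder is illustrative:

cache = PyObjectCache(lambda: {"decode": [], "prefill": []})

for step in range(3):             # three scheduler iterations
    out = cache.get_object()      # same dict object each iteration
    out["decode"].clear()         # caller clears stale state before reuse
    out["prefill"].clear()
    out["decode"].append(f"seq-{step}")
    cache.reset()                 # recycle everything for the next pass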
def _schedule_swapped(
self,
@@ -820,11 +874,15 @@
# Update waiting requests.
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
self.running.extend(
[s.seq_group for s in swapped_in.decode_seq_groups])
if len(prefills.seq_groups) > 0:
self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend(running_scheduled.decode_seq_groups_list)
if len(swapped_in.decode_seq_groups) > 0:
self.running.extend(
[s.seq_group for s in swapped_in.decode_seq_groups])
# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)
preempted = (len(running_scheduled.preempted) +
@@ -834,18 +892,30 @@
# doesn't allow chunked prefills.
assert len(running_scheduled.prefill_seq_groups) == 0
assert len(swapped_in.prefill_seq_groups) == 0
# Merge lists
num_prefill_groups = len(prefills.seq_groups)
if num_prefill_groups > 0:
scheduled_seq_groups = prefills.seq_groups
scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
else:
scheduled_seq_groups = running_scheduled.decode_seq_groups
scheduled_seq_groups.extend(swapped_in.decode_seq_groups)
blocks_to_copy = running_scheduled.blocks_to_copy
blocks_to_copy.extend(swapped_in.blocks_to_copy)
ignored_seq_groups = prefills.ignored_seq_groups
ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)
return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups +
running_scheduled.decode_seq_groups +
swapped_in.decode_seq_groups),
num_prefill_groups=len(prefills.seq_groups),
scheduled_seq_groups=scheduled_seq_groups,
num_prefill_groups=num_prefill_groups,
num_batched_tokens=budget.num_batched_tokens,
blocks_to_swap_in=swapped_in.blocks_to_swap_in,
blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
blocks_to_copy=running_scheduled.blocks_to_copy +
swapped_in.blocks_to_copy,
ignored_seq_groups=prefills.ignored_seq_groups +
swapped_in.infeasible_seq_groups,
blocks_to_copy=blocks_to_copy,
ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=running_scheduled.num_lookahead_slots,
running_queue_size=len(self.running),
preempted=preempted,
@@ -963,6 +1033,9 @@
scheduler_outputs = self._schedule()
now = time.time()
if not self.cache_config.enable_prefix_caching:
common_computed_block_nums = []
# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
for i, scheduled_seq_group in enumerate(
@@ -971,10 +1044,15 @@
token_chunk_size = scheduled_seq_group.token_chunk_size
seq_group.maybe_set_first_scheduled_time(now)
seq_group_metadata = self._seq_group_metadata_cache.get_object()
seq_group_metadata.seq_data.clear()
seq_group_metadata.block_tables.clear()
# seq_id -> SequenceData
seq_data: Dict[int, SequenceData] = {}
seq_data: Dict[int, SequenceData] = seq_group_metadata.seq_data
# seq_id -> physical block numbers
block_tables: Dict[int, List[int]] = {}
block_tables: Dict[int,
List[int]] = seq_group_metadata.block_tables
if seq_group.is_encoder_decoder():
# Encoder associated with SequenceGroup
@@ -993,9 +1071,10 @@
block_tables[seq_id] = self.block_manager.get_block_table(seq)
self.block_manager.access_all_blocks_in_seq(seq, now)
common_computed_block_nums = (
self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
if self.cache_config.enable_prefix_caching:
common_computed_block_nums = (
self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
do_sample = True
if seq_group.is_prefill():
@@ -1014,7 +1093,8 @@
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
is_prompt = seq_group.is_prefill()
seq_group_metadata = SequenceGroupMetadata(
seq_group_metadata.__init__(
request_id=seq_group.request_id,
is_prompt=is_prompt,
seq_data=seq_data,
@@ -1045,6 +1125,8 @@
self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group)
self._seq_group_metadata_cache.reset()
return seq_group_metadata_list, scheduler_outputs
def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
@@ -1093,7 +1175,8 @@
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
cows = self.block_manager.append_slots(seq, num_lookahead_slots)
blocks_to_copy.extend(cows)
if len(cows) > 0:
blocks_to_copy.extend(cows)
def _preempt(
self,

View File

@@ -1,10 +1,12 @@
from vllm.model_executor.parameter import (BasevLLMParameter,
PackedvLLMParameter)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingMetadataCache)
from vllm.model_executor.utils import set_random_seed
__all__ = [
"SamplingMetadata",
"SamplingMetadataCache",
"set_random_seed",
"BasevLLMParameter",
"PackedvLLMParameter",

View File

@@ -8,8 +8,9 @@ import torch
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.triton_utils.sample import get_num_triton_sampler_splits
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
make_tensor_with_pad, maybe_expand_dim)
from vllm.utils import (PyObjectCache, async_tensor_h2d,
is_pin_memory_available, make_tensor_with_pad,
maybe_expand_dim)
_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
@@ -62,6 +63,39 @@ class SequenceGroupToSample:
assert self.query_len is not None
def gen_seq_group_to_sample_builder(num_seqs: int):
return lambda: SequenceGroupToSample(
seq_ids=[0] * num_seqs,
sampling_params=None,
seq_data=None, # type: ignore
seq_len=0,
query_len=0,
generator=None,
is_prompt=True,
prompt_logprob_indices=[],
sample_indices=[])
class SamplingMetadataCache:
"""Used to cache SamplingMetadata objects between scheduler iterations
"""
def __init__(self):
self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}
def get_cached_seq_group_to_sample(self, num_seqs):
if num_seqs not in self._seq_group_to_sample_cache:
self._seq_group_to_sample_cache[num_seqs] = PyObjectCache(
gen_seq_group_to_sample_builder(num_seqs))
obj = self._seq_group_to_sample_cache[num_seqs].get_object()
return obj
def reset(self):
for cache in self._seq_group_to_sample_cache.values():
cache.reset()
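The cache keeps one PyObjectCache per sequence count, so a recycled SequenceGroupToSample always comes back with a seq_ids list of the right length. A short sketch of the recycling behavior, assuming the builder's defaults construct cleanly (they are the same defaults this commit uses):

cache = SamplingMetadataCache()
first = cache.get_cached_seq_group_to_sample(num_seqs=2)
cache.reset()
again = cache.get_cached_seq_group_to_sample(num_seqs=2)
assert again is first            # same object, recycled after reset()
assert len(again.seq_ids) == 2   # sized for its bucket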
class SamplingMetadata:
"""Metadata for input sequences. Used in sampler.
@@ -121,6 +155,7 @@ class SamplingMetadata:
device: str,
pin_memory: bool,
generators: Optional[Dict[str, torch.Generator]] = None,
cache: Optional[SamplingMetadataCache] = None,
) -> "SamplingMetadata":
(
seq_groups,
@@ -128,7 +163,7 @@
categorized_sample_indices,
num_prompts,
) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
device, generators)
device, generators, cache)
selected_token_indices = async_tensor_h2d(selected_token_indices,
dtype=torch.long,
target_device=device,
@@ -164,6 +199,7 @@ def _prepare_seq_groups(
query_lens: Optional[List[int]],
device: str,
generators: Optional[Dict[str, torch.Generator]] = None,
cache: Optional[SamplingMetadataCache] = None,
) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
SamplingType, List[Tuple[int, int]]], int]:
"""Prepare sequence groups and indices for sampling.
@@ -210,15 +246,27 @@
num_prompts = 0
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
seq_ids = seq_group_metadata.seq_data.keys()
if cache is not None:
sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))
for j, seq_id in enumerate(seq_ids):
sample_obj.seq_ids[j] = seq_id
sample_obj.prompt_logprob_indices.clear()
sample_obj.sample_indices.clear()
sampling_params = seq_group_metadata.sampling_params
is_prompt = seq_group_metadata.is_prompt
generator: Optional[torch.Generator] = None
# If the current seq group is in decode stage, it is None.
seq_len: Optional[int] = None
query_len: Optional[int] = None
prompt_logprob_indices: List[int] = []
sample_indices: List[int] = []
prompt_logprob_indices: List[int] = \
sample_obj.prompt_logprob_indices if cache is not None else []
sample_indices: List[int] = \
sample_obj.sample_indices if cache is not None else []
do_sample = seq_group_metadata.do_sample
if seq_group_metadata.is_prompt:
@@ -290,9 +338,16 @@
logit_idx += sample_len
sample_idx += sample_len
seq_groups.append(
SequenceGroupToSample(
seq_ids=seq_ids,
if cache is not None:
sample_obj.sampling_params = sampling_params
sample_obj.seq_data = seq_group_metadata.seq_data
sample_obj.seq_len = seq_len
sample_obj.query_len = query_len
sample_obj.generator = generator
sample_obj.is_prompt = is_prompt
else:
sample_obj = SequenceGroupToSample(
seq_ids=list(seq_ids),
sampling_params=sampling_params,
seq_data=seq_group_metadata.seq_data,
seq_len=seq_len,
@@ -300,7 +355,13 @@
generator=generator,
is_prompt=is_prompt,
prompt_logprob_indices=list(prompt_logprob_indices),
sample_indices=list(sample_indices)))
sample_indices=list(sample_indices))
seq_groups.append(sample_obj)
if cache is not None:
cache.reset()
return (seq_groups, selected_token_indices, categorized_sample_indices,
num_prompts)

View File

@@ -139,7 +139,7 @@ class RequestOutput:
CompletionOutput(
seqs.index(seq),
seq.get_output_text_to_return(text_buffer_length),
seq.get_output_token_ids(),
seq.data._output_token_ids, # type: ignore
seq.get_cumulative_logprob() if include_logprobs else None,
seq.output_logprobs if include_logprobs else None,
SequenceStatus.get_finished_reason(seq.status),

View File

@@ -1,7 +1,6 @@
"""Sequence and its related classes."""
import copy
import enum
import math
from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
@@ -330,7 +329,7 @@ class Sequence:
@property
def n_blocks(self) -> int:
return math.ceil(self.get_len() / self.block_size)
return (self.get_len() + self.block_size - 1) // self.block_size
@property
def prompt(self) -> Optional[str]:
@@ -514,7 +513,9 @@
) -> None:
self.request_id = request_id
self.seqs = seqs
self.is_single_seq = len(seqs) == 1
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
self.sampling_params = sampling_params
self.metrics = RequestMetrics(arrival_time=arrival_time,
last_token_time=arrival_time,
@@ -635,6 +636,10 @@
) -> List[Sequence]:
if status is None:
return self.seqs
if self.is_single_seq:
return self.seqs if self.seqs[0].status == status else []
return [seq for seq in self.seqs if seq.status == status]
def is_encoder_decoder(self) -> bool:
@@ -644,6 +649,9 @@
return self.encoder_seq
def get_unfinished_seqs(self) -> List[Sequence]:
if self.is_single_seq:
return self.seqs if not self.seqs[0].is_finished() else []
return [seq for seq in self.seqs if not seq.is_finished()]
def get_finished_seqs(self) -> List[Sequence]:
@@ -668,12 +676,21 @@
if status is None:
return len(self.seqs)
if self.is_single_seq:
return 1 if self.seqs[0].status == status else 0
return len(self.get_seqs(status))
def num_unfinished_seqs(self) -> int:
if self.is_single_seq:
return 1 if not self.seqs[0].is_finished() else 0
return len(self.get_unfinished_seqs())
def num_finished_seqs(self) -> int:
if self.is_single_seq:
return 1 if self.seqs[0].is_finished() else 0
return len(self.get_finished_seqs())
def find(self, seq_id: int) -> Sequence:
@@ -686,12 +703,14 @@
raise ValueError(f"Sequence {seq.seq_id} already exists.")
self.seqs_dict[seq.seq_id] = seq
self.seqs.append(seq)
self.is_single_seq = len(self.seqs) == 1
def remove(self, seq_id: int) -> None:
seq = self.seqs_dict.pop(seq_id, None)
if seq is None:
raise ValueError(f"Sequence {seq_id} not found.")
self.seqs.remove(seq)
self.is_single_seq = len(self.seqs) == 1
def is_finished(self) -> bool:
return all(seq.is_finished() for seq in self.seqs)
@@ -775,9 +794,10 @@
# TODO: We should maintain this states out of the sequence group.
self.num_speculative_tokens = None
if self._token_chunk_size is None:
if seq_data is not None and self._token_chunk_size is None:
if is_prompt:
self._token_chunk_size = list(seq_data.values())[0].get_len()
self._token_chunk_size = next(iter(
seq_data.values())).get_len()
else:
self._token_chunk_size = 1

View File

@@ -261,6 +261,44 @@ class LRUCache(Generic[T]):
self.cache.clear()
class PyObjectCache:
"""Used to cache python objects to avoid object allocations
across scheduler iterations.
"""
def __init__(self, obj_builder):
self._obj_builder = obj_builder
self._index = 0
self._obj_cache = []
for _ in range(128):
self._obj_cache.append(self._obj_builder())
def _grow_cache(self):
# Double the size of the cache
num_objs = len(self._obj_cache)
for _ in range(num_objs):
self._obj_cache.append(self._obj_builder())
def get_object(self):
"""Returns a pre-allocated cached object. If there is not enough
objects, then the cache size will double.
"""
if self._index >= len(self._obj_cache):
self._grow_cache()
assert self._index < len(self._obj_cache)
obj = self._obj_cache[self._index]
self._index += 1
return obj
def reset(self):
"""Makes all cached-objects available for the next scheduler iteration.
"""
self._index = 0
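PyObjectCache trades memory for allocation-free hot paths: 128 objects are built up front, the pool doubles on demand, and reset() is an O(1) bulk free. Example use with a throwaway builder:

cache = PyObjectCache(lambda: [0] * 8)   # any zero-argument builder works

for step in range(1000):
    buf = cache.get_object()   # recycled list; no allocation once warm
    buf[0] = step
    cache.reset()              # hand everything back each iteration

Note that the cache never clears recycled objects itself; callers are expected to reset any stale state, as the scheduler does earlier in this commit with its ret.*.clear() calls.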
def is_hip() -> bool:
return torch.version.hip is not None

View File

@@ -1,5 +1,6 @@
import dataclasses
import gc
import itertools
import time
import warnings
import weakref
@@ -35,7 +36,7 @@ from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import (supports_lora,
@@ -50,8 +51,8 @@ from vllm.prompt_adapter.worker_manager import (
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, flatten_2d_lists,
get_kv_cache_torch_dtype, is_hip,
from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
flatten_2d_lists, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
@@ -178,6 +179,20 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
class InterDataForSeqGroup:
"""Intermediate data for the current sequence group."""
def simple_reinit(self):
self.input_tokens[0].clear() # type: ignore
self.input_positions[0].clear() # type: ignore
self.seq_lens[0] = 0 # type: ignore
self.orig_seq_lens[0] = 0 # type: ignore
self.query_lens[0] = 0 # type: ignore
self.context_lens[0] = 0 # type: ignore
self.curr_sliding_window_blocks[0] = 0 # type: ignore
self.lora_index_mapping.clear() # type: ignore
self.lora_prompt_mapping.clear() # type: ignore
self.lora_requests.clear() # type: ignore
self.prompt_adapter_index_mapping.clear() # type: ignore
self.prompt_adapter_prompt_mapping.clear() # type: ignore
def __init__(
self,
*,
@@ -220,35 +235,121 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Whether the prefix cache is hit (prefill only).
prefix_cache_hit: bool = False,
reinit: bool = False,
reinit_use_defaults: bool = False,
):
if reinit:
assert len(self.seq_ids) == len(seq_ids) # type: ignore
for i, seq_id in enumerate(seq_ids):
self.seq_ids[i] = seq_id # type: ignore
else:
self.seq_ids = seq_ids
self.request_id = request_id
self.seq_ids = seq_ids
self.is_prompt = is_prompt
self.block_tables = block_tables
self.computed_block_nums = computed_block_nums
self.n_seqs = n_seqs
self.input_tokens = input_tokens or []
self.input_positions = input_positions or []
self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or []
self.query_lens = query_lens or []
self.context_lens = context_lens or []
self.curr_sliding_window_blocks = curr_sliding_window_blocks or []
self.lora_index_mapping = lora_index_mapping or []
self.lora_prompt_mapping = lora_prompt_mapping or []
self.lora_requests = lora_requests or set()
if reinit:
if len(self.seq_ids) == 1 and reinit_use_defaults:
self.simple_reinit()
else:
if input_tokens:
self.input_tokens = input_tokens
else:
for seq_id in range(len(self.seq_ids)):
self.input_tokens[seq_id].clear()
if input_positions:
self.input_positions = input_positions
else:
for seq_id in range(len(self.seq_ids)):
self.input_positions[seq_id].clear()
if seq_lens:
self.seq_lens = seq_lens
else:
for seq_id in range(len(self.seq_ids)):
self.seq_lens[seq_id] = 0
if orig_seq_lens:
self.orig_seq_lens = orig_seq_lens
else:
for seq_id in range(len(self.seq_ids)):
self.orig_seq_lens[seq_id] = 0
if query_lens:
self.query_lens = query_lens
else:
for seq_id in range(len(self.seq_ids)):
self.query_lens[seq_id] = 0
if context_lens:
self.context_lens = context_lens
else:
for seq_id in range(len(self.seq_ids)):
self.context_lens[seq_id] = 0
if curr_sliding_window_blocks:
self.curr_sliding_window_blocks = \
curr_sliding_window_blocks
else:
for seq_id in range(len(self.seq_ids)):
self.curr_sliding_window_blocks[seq_id] = 0
if lora_index_mapping:
self.lora_index_mapping = lora_index_mapping
else:
self.lora_index_mapping.clear()
if lora_prompt_mapping:
self.lora_prompt_mapping = lora_prompt_mapping
else:
self.lora_prompt_mapping.clear()
if lora_requests:
self.lora_requests = lora_requests
else:
self.lora_requests.clear()
if prompt_adapter_index_mapping:
self.prompt_adapter_index_mapping = \
prompt_adapter_index_mapping
else:
self.prompt_adapter_index_mapping.clear()
if prompt_adapter_prompt_mapping:
self.prompt_adapter_prompt_mapping = \
prompt_adapter_prompt_mapping
else:
self.prompt_adapter_prompt_mapping.clear()
else:
self.input_tokens = input_tokens or []
self.input_positions = input_positions or []
self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or []
self.query_lens = query_lens or []
self.context_lens = context_lens or []
self.curr_sliding_window_blocks = \
curr_sliding_window_blocks or []
self.lora_index_mapping = lora_index_mapping or []
self.lora_prompt_mapping = lora_prompt_mapping or []
self.lora_requests = lora_requests or set()
self.prompt_adapter_index_mapping = (
prompt_adapter_index_mapping or [])
self.prompt_adapter_prompt_mapping = (
prompt_adapter_prompt_mapping or [])
self.prompt_adapter_index_mapping = (prompt_adapter_index_mapping
or [])
self.prompt_adapter_prompt_mapping = (prompt_adapter_prompt_mapping
or [])
self.prompt_adapter_request = prompt_adapter_request
self.multi_modal_inputs = multi_modal_inputs
self.prefix_cache_hit = prefix_cache_hit
self.__post_init__()
if not reinit:
self.__post_init__()
def __post_init__(self):
self.n_seqs = len(self.seq_ids)
@@ -261,8 +362,36 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.context_lens = [0] * self.n_seqs
self.curr_sliding_window_blocks = [0] * self.n_seqs
self.lora_index_mapping = [[] for _ in range(self.n_seqs)]
self.lora_prompt_mapping = [[] for _ in range(self.n_seqs)]
self.lora_index_mapping = []
self.lora_prompt_mapping = []
def gen_inter_data_builder(self, num_seqs: int):
return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup(
request_id="",
seq_ids=[0] * num_seqs,
is_prompt=True,
block_tables=None,
computed_block_nums=[])
def init_cached_inter_data(self, *args, **kwargs):
assert len(args) == 0
assert "seq_ids" in kwargs
seq_ids = kwargs["seq_ids"]
num_seqs = len(seq_ids)
# The inter-data cache is per model_runner
inter_data_cache = self.runner.inter_data_cache
if num_seqs not in inter_data_cache:
inter_data_cache[num_seqs] = PyObjectCache(
self.gen_inter_data_builder(num_seqs))
obj = inter_data_cache[num_seqs].get_object()
obj.__init__(*args, **kwargs)
return obj
def reset_cached_inter_data(self):
for cache in self.runner.inter_data_cache.values():
cache.reset()
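As with the sampling-metadata cache, inter-data objects are bucketed by len(seq_ids) and revived in place: obj.__init__(..., reinit=True) re-runs the constructor on a recycled instance so its nested lists are cleared and reused rather than reallocated. A sketch of that revive-in-place idiom on a toy class (Scratch is illustrative, not part of vLLM):

class Scratch:
    def __init__(self, values=None, reinit=False):
        if reinit:
            self.values.clear()    # reuse the existing list in place
        else:
            self.values = []       # first construction allocates
        if values:
            self.values.extend(values)

obj = Scratch([1, 2])
obj.__init__([3], reinit=True)     # revives obj; no new list allocated
assert obj.values == [3]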
def __init__(self,
runner: "GPUModelRunnerBase",
@@ -337,17 +466,29 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Compute tokens.
if inter_data.is_prompt:
tokens = seq_data.get_token_ids()[context_len:seq_len]
tokens = seq_data.get_token_ids()
if context_len != 0 or seq_len < len(tokens):
tokens = tokens[context_len:seq_len]
else:
# Optimization: for decode, get_token_ids would copy the
# entire token list; only the last token is needed.
tokens = [seq_data.get_last_token_id()]
tokens = seq_data.get_last_token_id()
inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx] = tokens
inter_data.input_positions[seq_idx] = list(range(context_len, seq_len))
if isinstance(tokens, list):
inter_data.input_tokens[seq_idx].extend(tokens)
else:
inter_data.input_tokens[seq_idx].append(tokens)
if (seq_len - context_len) == 1:
inter_data.input_positions[seq_idx].append(seq_len - 1)
else:
inter_data.input_positions[seq_idx].extend(
range(context_len, seq_len))
inter_data.query_lens[
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
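For decode, the builder now fetches only the last token id; get_token_ids() would copy the sequence's entire prompt-plus-output list on every step. The cost difference on a stand-in list:

token_ids = list(range(10_000))    # stand-in for a sequence's token ids
prefill_tokens = token_ids[:]      # O(n) copy: paid once, at prefill
decode_token = token_ids[-1]       # O(1): all a decode step needs

The append/extend split above mirrors this: a decode step appends one int, while prefill extends with a range.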
@@ -471,7 +612,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
"""Add a sequence group to the builder."""
seq_ids = list(seq_group_metadata.seq_data.keys())
seq_ids = seq_group_metadata.seq_data.keys()
n_seqs = len(seq_ids)
is_prompt = seq_group_metadata.is_prompt
@@ -479,12 +620,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
assert n_seqs == 1
self.decode_only = False
inter_data = self.InterDataForSeqGroup(
inter_data = self.init_cached_inter_data(
request_id=seq_group_metadata.request_id,
seq_ids=seq_ids,
is_prompt=is_prompt,
block_tables=seq_group_metadata.block_tables,
computed_block_nums=seq_group_metadata.computed_block_nums)
computed_block_nums=seq_group_metadata.computed_block_nums,
reinit=True,
reinit_use_defaults=True)
self.inter_data_list.append(inter_data)
for seq_idx in range(n_seqs):
@@ -504,18 +648,21 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
create on-device tensors.
"""
# Combine and flatten intermediate data.
input_tokens = flatten_2d_lists([
flatten_2d_lists(inter_data.input_tokens)
for inter_data in self.inter_data_list
])
input_tokens = []
for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens:
input_tokens.extend(cur_input_tokens)
if not input_tokens:
# This may happen when all prefill requests hit
# prefix caching and there is no decode request.
return self.model_input_cls()
input_positions = flatten_2d_lists([
flatten_2d_lists(inter_data.input_positions)
for inter_data in self.inter_data_list
])
input_positions = []
for inter_data in self.inter_data_list:
for cur_input_positions in inter_data.input_positions:
input_positions.extend(cur_input_positions)
seq_lens = []
max_decode_seq_len = 0
for inter_data in self.inter_data_list:
@@ -523,8 +670,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
if not inter_data.is_prompt:
max_decode_seq_len = max(max_decode_seq_len,
max(inter_data.seq_lens))
query_lens = flatten_2d_lists(
[inter_data.query_lens for inter_data in self.inter_data_list])
query_lens = []
for inter_data in self.inter_data_list:
query_lens.extend(inter_data.query_lens)
# Mapping from request IDs to sequence IDs. Used for Jamba models
# that manage the cache by themselves.
request_ids_to_seq_ids = {
@@ -547,8 +696,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
batch_size = graph_batch_size
# Tokens and positions.
input_tokens.extend([0] * cuda_graph_pad_size)
input_positions.extend([0] * cuda_graph_pad_size)
if cuda_graph_pad_size:
input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
assert self.runner.device is not None
input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
self.runner.device,
@@ -558,7 +708,8 @@
self.runner.pin_memory)
# Sequence and query lengths.
seq_lens.extend([1] * cuda_graph_pad_size)
if cuda_graph_pad_size:
seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
# Attention metadata.
attn_metadata = self.attn_metadata_builder.build(
@@ -574,11 +725,14 @@
flatten_2d_lists(inter_data.lora_index_mapping)
for inter_data in self.inter_data_list
])
lora_index_mapping.extend([0] * cuda_graph_pad_size)
if cuda_graph_pad_size:
lora_index_mapping.extend(
itertools.repeat(0, cuda_graph_pad_size))
lora_prompt_mapping = flatten_2d_lists([
flatten_2d_lists(inter_data.lora_prompt_mapping)
for inter_data in self.inter_data_list
])
lora_mapping = LoRAMapping(
**dict(index_mapping=lora_index_mapping,
prompt_mapping=lora_prompt_mapping,
@@ -595,7 +749,9 @@
inter_data.prompt_adapter_index_mapping
for inter_data in self.inter_data_list
])
prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
if cuda_graph_pad_size:
prompt_adapter_index_mapping.extend(
itertools.repeat(0, cuda_graph_pad_size))
prompt_adapter_prompt_mapping = flatten_2d_lists([
inter_data.prompt_adapter_prompt_mapping
for inter_data in self.inter_data_list
@@ -717,6 +873,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
set_cpu_offload_max_bytes(
int(self.cache_config.cpu_offload_gb * 1024**3))
# Used to cache python objects
self.inter_data_cache: Dict[int, PyObjectCache] = {}
self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache()
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with CudaMemoryProfiler() as m:
@@ -843,6 +1004,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
for seq_group_metadata in seq_group_metadata_list:
builder.add_seq_group(seq_group_metadata)
builder.reset_cached_inter_data()
return builder.build() # type: ignore
@torch.inference_mode()
@@ -1276,7 +1440,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, model_input.seq_lens,
model_input.query_lens, self.device, self.pin_memory,
generators)
generators, self.sampling_metadata_cache)
else:
sampling_metadata = None
is_prompt = (seq_group_metadata_list[0].is_prompt