diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 8a895bbdc2dd7..5710aa1930b79 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -259,7 +259,11 @@ class FlashAttentionMetadataBuilder( block_table = block_tables[seq_id] elif ((chunked_prefill_enabled or not is_prompt) and block_tables is not None): - block_table = block_tables[seq_id][-curr_sliding_window_block:] + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] self.block_tables.append(block_table) # Compute slot mapping. diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index f7cb2ee996501..3ca668cb4e029 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -68,13 +68,21 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], # tokens are masked and the slot mapping will be # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. block_table = block_tables[seq_id] - slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len)) - for i in range(max(start_idx, context_len), seq_len): + + def add_slot(i): block_number = block_table[i // block_size] block_offset = i % block_size slot = block_number * block_size + block_offset slot_mapping.append(slot) + if start_idx == 0 and (seq_len - context_len) == 1: + # Optimization for common-case of decoding next token + add_slot(seq_len - 1) + else: + slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len)) + for i in range(max(start_idx, context_len), seq_len): + add_slot(i) + TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata') diff --git a/vllm/block.py b/vllm/block.py index 0b8ef7d4b73d9..95286048d9115 100644 --- a/vllm/block.py +++ b/vllm/block.py @@ -1,5 +1,5 @@ """Token blocks.""" -from typing import List +from typing import List, Optional from vllm.utils import Device @@ -37,5 +37,47 @@ class PhysicalTokenBlock: f'computed={self.computed})') -# Mapping: logical block number -> physical block. 
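Two of the hunks above are small list-handling optimizations: the flash-attention change stops slicing the block table when `curr_sliding_window_block == 0` (a `[-0:]` slice still copies the whole list, so the guard presumably avoids that copy when no sliding-window limit is in effect), and `compute_slot_mapping` gains a fast path for the single-token decode case. A small sketch of the slot arithmetic the new `add_slot` helper performs, using toy values only:

    # Toy values, not vLLM's real data structures.
    block_size = 16
    block_table = [7, 3, 9]      # physical block numbers for logical blocks 0..2

    def slot_for(i: int) -> int:
        # Same arithmetic as the add_slot() helper above.
        return block_table[i // block_size] * block_size + i % block_size

    # Chunked prefill: every position in [context_len, seq_len) needs a slot.
    prefill_slots = [slot_for(i) for i in range(0, 20)]

    # Decode (seq_len - context_len == 1): only the final position is needed,
    # so the new fast path skips the PAD_SLOT_ID padding and the loop entirely.
    decode_slot = slot_for(20 - 1)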
-BlockTable = List[PhysicalTokenBlock] +class BlockTable: + """Holds a list of blocks with caching of their associated block_ids + """ + + def __init__(self, blocks: Optional[List[PhysicalTokenBlock]] = None): + self._blocks: List[PhysicalTokenBlock] = [] + self._block_ids: List[int] = [] + + if blocks is not None: + for block in blocks: + self.append(block) + + def append(self, block: PhysicalTokenBlock): + self._blocks.append(block) + self._block_ids.append(block.block_number) + + def __len__(self) -> int: + return len(self._blocks) + + def __getitem__(self, key): + return self._blocks[key] + + def __setitem__(self, key, value): + if isinstance(key, slice): + blocks = value + self._blocks[key] = blocks + self._block_ids[key] = [b.block_number for b in blocks] + else: + block = value + self._blocks[key] = block + self._block_ids[key] = block.block_number + + def reset(self): + self._blocks = [] + self._block_ids = [] + + def copy(self) -> "BlockTable": + return BlockTable(self._blocks) + + def list(self) -> List[PhysicalTokenBlock]: + return self._blocks + + def ids(self) -> List[int]: + return self._block_ids diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index d81648caa5851..622aca66a96de 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -170,7 +170,7 @@ class UncachedBlockAllocator(BlockAllocatorBase): self.num_blocks = num_blocks # Initialize the free blocks. - self.free_blocks: BlockTable = [] + self.free_blocks: List[PhysicalTokenBlock] = [] for i in range(num_blocks): block = PhysicalTokenBlock(device=device, block_number=i, @@ -256,6 +256,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): Device.CPU, block_size, num_cpu_blocks) # Mapping: seq_id -> BlockTable. self.block_tables: Dict[int, BlockTable] = {} + # Mapping: req_id -> BlockTable # Note that each SequenceGroup has a unique # request ID @@ -299,7 +300,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): # Allocate new physical token blocks that will store the prompt tokens. num_prompt_blocks = seq.n_blocks - block_table: BlockTable = [] + block_table: BlockTable = BlockTable() for logical_idx in range(num_prompt_blocks): if (self.block_sliding_window is not None and logical_idx >= self.block_sliding_window): @@ -326,15 +327,19 @@ class BlockSpaceManagerV1(BlockSpaceManager): # # NOTE: Here we assume that all sequences in the group have the same # decoder prompt. - seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + wait_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + seq = wait_seqs[0] block_table: BlockTable = \ self._allocate_sequence(seq, seq_group.num_seqs(), is_encoder_decoder) # Assign the self-attention block tables for each sequence. - for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): - self.block_tables[seq.seq_id] = block_table.copy() + if len(wait_seqs) == 1: + self.block_tables[wait_seqs[0].seq_id] = block_table + else: + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): + self.block_tables[seq.seq_id] = block_table.copy() # Allocate encoder sequence if is_encoder_decoder: @@ -476,6 +481,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): return src_block_table = self.block_tables[parent_seq.seq_id] self.block_tables[child_seq.seq_id] = src_block_table.copy() + # When using a sliding window, blocks will be eventually reused. # In this case the block tables will contain repeated blocks. 
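The `BlockTable` class added above replaces the old `List[PhysicalTokenBlock]` alias and keeps a parallel list of block numbers in sync, so `block_manager_v1.get_block_table()` can return the cached `ids()` instead of rebuilding the id list on every call. A rough usage sketch with a stub block type (illustrative only; the real class holds `PhysicalTokenBlock` objects):

    from dataclasses import dataclass

    @dataclass
    class StubBlock:              # stand-in for PhysicalTokenBlock
        block_number: int

    table = BlockTable()          # assumes the BlockTable class defined above
    table.append(StubBlock(4))
    table.append(StubBlock(11))

    assert table.ids() == [4, 11]        # O(1): ids were cached on append
    assert len(table) == 2
    table[1] = StubBlock(12)             # __setitem__ keeps _block_ids in sync
    assert table.ids() == [4, 12]

    fork = table.copy()                  # new BlockTable backed by new lists
    assert fork.ids() == table.ids()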
# When forking, we must make sure that each block's `ref_count` @@ -527,7 +533,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): dest_allocator: BlockAllocatorBase, mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock]) -> BlockTable: - new_block_table = [] + new_block_table: BlockTable = BlockTable() for from_block in block_table: if from_block in mapping: @@ -553,8 +559,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): self.block_tables[seq.seq_id] = \ self._swap_block_table(self.block_tables[seq.seq_id], - self.cpu_allocator, - self.gpu_allocator, + self.cpu_allocator, self.gpu_allocator, mapping) if seq_group.is_encoder_decoder(): @@ -580,8 +585,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): self.block_tables[seq.seq_id] = \ self._swap_block_table(self.block_tables[seq.seq_id], - self.gpu_allocator, - self.cpu_allocator, + self.gpu_allocator, self.cpu_allocator, mapping) if seq_group.is_encoder_decoder(): @@ -636,8 +640,7 @@ class BlockSpaceManagerV1(BlockSpaceManager): self.cross_block_tables.clear() def get_block_table(self, seq: Sequence) -> List[int]: - block_table = self.block_tables[seq.seq_id] - return [block.block_number for block in block_table] + return self.block_tables[seq.seq_id].ids() def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: block_table = self.cross_block_tables[seq_group.request_id] diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 950abfccba4c3..a40f6e2e248b9 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -13,6 +13,7 @@ from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) +from vllm.utils import PyObjectCache logger = init_logger(__name__) @@ -176,10 +177,10 @@ class SchedulerRunningOutputs: enough memory, it can be preempted (for recompute) or swapped out. """ # Selected sequences that are running and in a decoding phase. - decode_seq_groups: List[SequenceGroup] + decode_seq_groups: List[ScheduledSequenceGroup] # Selected sequences that are running and in a prefill phase. # I.e., it means the prefill has been chunked. - prefill_seq_groups: List[SequenceGroup] + prefill_seq_groups: List[ScheduledSequenceGroup] # The preempted sequences. preempted: List[SequenceGroup] # Sequences that are swapped out. @@ -191,6 +192,10 @@ class SchedulerRunningOutputs: # The number of slots for lookahead decoding. 
num_lookahead_slots: int + # Optimization for fast-access to seq_group lists + decode_seq_groups_list: List[SequenceGroup] + prefill_seq_groups_list: List[SequenceGroup] + @classmethod def create_empty(cls) -> "SchedulerRunningOutputs": return SchedulerRunningOutputs( @@ -201,6 +206,8 @@ class SchedulerRunningOutputs: blocks_to_swap_out=[], blocks_to_copy=[], num_lookahead_slots=0, + decode_seq_groups_list=[], + prefill_seq_groups_list=[], ) @@ -259,6 +266,30 @@ class SchedulerPrefillOutputs: ) +def seq_group_metadata_builder(): + return SequenceGroupMetadata(request_id="", + is_prompt=False, + seq_data={}, + sampling_params=None, + block_tables={}) + + +def scheduler_running_outputs_builder(): + return SchedulerRunningOutputs(decode_seq_groups=[], + prefill_seq_groups=[], + preempted=[], + swapped_out=[], + blocks_to_swap_out=[], + blocks_to_copy=[], + num_lookahead_slots=0, + prefill_seq_groups_list=[], + decode_seq_groups_list=[]) + + +def scheduled_seq_group_builder(): + return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) + + class Scheduler: def __init__( @@ -331,6 +362,14 @@ class Scheduler: else 0) self.num_cumulative_preemption: int = 0 + # Used to cache python objects + self._seq_group_metadata_cache: PyObjectCache = PyObjectCache( + seq_group_metadata_builder) + self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache( + scheduler_running_outputs_builder) + self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache( + scheduled_seq_group_builder) + @property def lora_enabled(self) -> bool: return bool(self.lora_config) @@ -441,14 +480,30 @@ class Scheduler: Returns: SchedulerRunningOutputs. """ - # Blocks that need to be swapped or copied before model execution. - blocks_to_swap_out: List[Tuple[int, int]] = [] - blocks_to_copy: List[Tuple[int, int]] = [] + ret: SchedulerRunningOutputs = \ + self._scheduler_running_outputs_cache.get_object() + ret.blocks_to_swap_out.clear() + ret.blocks_to_copy.clear() + ret.decode_seq_groups.clear() + ret.prefill_seq_groups.clear() + ret.preempted.clear() + ret.swapped_out.clear() - decode_seq_groups: List[ScheduledSequenceGroup] = [] - prefill_seq_groups: List[ScheduledSequenceGroup] = [] - preempted: List[SequenceGroup] = [] - swapped_out: List[SequenceGroup] = [] + ret.num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill=False) + + ret.decode_seq_groups_list.clear() + ret.prefill_seq_groups_list.clear() + + # Blocks that need to be swapped or copied before model execution. + blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out + blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy + + decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups + prefill_seq_groups: List[ + ScheduledSequenceGroup] = ret.prefill_seq_groups + preempted: List[SequenceGroup] = ret.preempted + swapped_out: List[SequenceGroup] = ret.swapped_out # NOTE(woosuk): Preemption happens only when there is no available slot # to keep all the sequence groups in the RUNNING state. 
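The scheduler now draws `SchedulerRunningOutputs`, `ScheduledSequenceGroup`, and `SequenceGroupMetadata` instances from `PyObjectCache` pools (the cache class itself is added to `vllm/utils.py` further down in this patch) instead of allocating them on every scheduling pass. A toy sketch of that reuse pattern, using a made-up payload class rather than the real scheduler types:

    class Outputs:                       # toy stand-in for SchedulerRunningOutputs
        def __init__(self):
            self.items = []

    cache = PyObjectCache(Outputs)       # assumes PyObjectCache from vllm/utils.py

    def schedule_step(work):
        out = cache.get_object()         # reuse a pre-allocated instance
        out.items.clear()                # caller must clear mutable fields
        out.items.extend(work)
        cache.reset()                    # hand the pool back for the next step
        return out

    step1 = schedule_step([1, 2, 3])
    step2 = schedule_step([4])           # the same underlying object is reused
    assert step1 is step2

Note that, as in `_schedule_running` above, the pool is reset before the result is consumed, so a returned object stays valid only until the next pass hands it out again.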
@@ -497,15 +552,19 @@ class Scheduler: else: self._append_slots(seq_group, blocks_to_copy) is_prefill = seq_group.is_prefill() + + scheduled_seq_group: ScheduledSequenceGroup = \ + self._scheduled_seq_group_cache.get_object() + scheduled_seq_group.seq_group = seq_group if is_prefill: - prefill_seq_groups.append( - ScheduledSequenceGroup( - seq_group=seq_group, - token_chunk_size=num_running_tokens)) + scheduled_seq_group.token_chunk_size = num_running_tokens + prefill_seq_groups.append(scheduled_seq_group) + ret.prefill_seq_groups_list.append(seq_group) else: - decode_seq_groups.append( - ScheduledSequenceGroup(seq_group=seq_group, - token_chunk_size=1)) + scheduled_seq_group.token_chunk_size = 1 + decode_seq_groups.append(scheduled_seq_group) + ret.decode_seq_groups_list.append(seq_group) + budget.add_num_batched_tokens(seq_group.request_id, num_running_tokens) # OPTIMIZATION: Note that get_max_num_running_seqs is @@ -518,15 +577,10 @@ class Scheduler: if curr_loras is not None and seq_group.lora_int_id > 0: curr_loras.add(seq_group.lora_int_id) - return SchedulerRunningOutputs( - decode_seq_groups=decode_seq_groups, - prefill_seq_groups=prefill_seq_groups, - preempted=preempted, - swapped_out=swapped_out, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=False)) + self._scheduler_running_outputs_cache.reset() + self._scheduled_seq_group_cache.reset() + + return ret def _schedule_swapped( self, @@ -820,11 +874,15 @@ class Scheduler: # Update waiting requests. self.waiting.extendleft(running_scheduled.preempted) # Update new running requests. - self.running.extend([s.seq_group for s in prefills.seq_groups]) - self.running.extend( - [s.seq_group for s in running_scheduled.decode_seq_groups]) - self.running.extend( - [s.seq_group for s in swapped_in.decode_seq_groups]) + if len(prefills.seq_groups) > 0: + self.running.extend([s.seq_group for s in prefills.seq_groups]) + + self.running.extend(running_scheduled.decode_seq_groups_list) + + if len(swapped_in.decode_seq_groups) > 0: + self.running.extend( + [s.seq_group for s in swapped_in.decode_seq_groups]) + # Update swapped requests. self.swapped.extend(running_scheduled.swapped_out) preempted = (len(running_scheduled.preempted) + @@ -834,18 +892,30 @@ class Scheduler: # doesn't allow chunked prefills. 
assert len(running_scheduled.prefill_seq_groups) == 0 assert len(swapped_in.prefill_seq_groups) == 0 + + # Merge lists + num_prefill_groups = len(prefills.seq_groups) + if num_prefill_groups > 0: + scheduled_seq_groups = prefills.seq_groups + scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) + else: + scheduled_seq_groups = running_scheduled.decode_seq_groups + scheduled_seq_groups.extend(swapped_in.decode_seq_groups) + + blocks_to_copy = running_scheduled.blocks_to_copy + blocks_to_copy.extend(swapped_in.blocks_to_copy) + + ignored_seq_groups = prefills.ignored_seq_groups + ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) + return SchedulerOutputs( - scheduled_seq_groups=(prefills.seq_groups + - running_scheduled.decode_seq_groups + - swapped_in.decode_seq_groups), - num_prefill_groups=len(prefills.seq_groups), + scheduled_seq_groups=scheduled_seq_groups, + num_prefill_groups=num_prefill_groups, num_batched_tokens=budget.num_batched_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=running_scheduled.blocks_to_copy + - swapped_in.blocks_to_copy, - ignored_seq_groups=prefills.ignored_seq_groups + - swapped_in.infeasible_seq_groups, + blocks_to_copy=blocks_to_copy, + ignored_seq_groups=ignored_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, running_queue_size=len(self.running), preempted=preempted, @@ -963,6 +1033,9 @@ class Scheduler: scheduler_outputs = self._schedule() now = time.time() + if not self.cache_config.enable_prefix_caching: + common_computed_block_nums = [] + # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] for i, scheduled_seq_group in enumerate( @@ -971,10 +1044,15 @@ class Scheduler: token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.maybe_set_first_scheduled_time(now) + seq_group_metadata = self._seq_group_metadata_cache.get_object() + seq_group_metadata.seq_data.clear() + seq_group_metadata.block_tables.clear() + # seq_id -> SequenceData - seq_data: Dict[int, SequenceData] = {} + seq_data: Dict[int, SequenceData] = seq_group_metadata.seq_data # seq_id -> physical block numbers - block_tables: Dict[int, List[int]] = {} + block_tables: Dict[int, + List[int]] = seq_group_metadata.block_tables if seq_group.is_encoder_decoder(): # Encoder associated with SequenceGroup @@ -993,9 +1071,10 @@ class Scheduler: block_tables[seq_id] = self.block_manager.get_block_table(seq) self.block_manager.access_all_blocks_in_seq(seq, now) - common_computed_block_nums = ( - self.block_manager.get_common_computed_block_ids( - seq_group.get_seqs(status=SequenceStatus.RUNNING))) + if self.cache_config.enable_prefix_caching: + common_computed_block_nums = ( + self.block_manager.get_common_computed_block_ids( + seq_group.get_seqs(status=SequenceStatus.RUNNING))) do_sample = True if seq_group.is_prefill(): @@ -1014,7 +1093,8 @@ class Scheduler: # It assumes the scheduled_seq_groups is ordered by # prefill < decoding. 
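The "Merge lists" block above drops the old `prefills.seq_groups + running_scheduled.decode_seq_groups + swapped_in.decode_seq_groups` concatenation, which allocates intermediate lists on every step, and instead extends one of the existing lists in place. In miniature (toy data):

    prefills, decodes, swapped = [object(), object()], [object()], []

    # Old style: `+` builds intermediate lists and a brand-new result list.
    merged_copy = prefills + decodes + swapped

    # New style: reuse the prefill list as the result and extend it in place.
    merged = prefills
    merged.extend(decodes)
    merged.extend(swapped)
    assert merged is prefills and len(merged) == 3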
is_prompt = seq_group.is_prefill() - seq_group_metadata = SequenceGroupMetadata( + + seq_group_metadata.__init__( request_id=seq_group.request_id, is_prompt=is_prompt, seq_data=seq_data, @@ -1045,6 +1125,8 @@ class Scheduler: self.block_manager.mark_blocks_as_computed( scheduled_seq_group.seq_group) + self._seq_group_metadata_cache.reset() + return seq_group_metadata_list, scheduler_outputs def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: @@ -1093,7 +1175,8 @@ class Scheduler: for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): cows = self.block_manager.append_slots(seq, num_lookahead_slots) - blocks_to_copy.extend(cows) + if len(cows) > 0: + blocks_to_copy.extend(cows) def _preempt( self, diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 5c767e22de4d0..7278c7fbe8bea 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,10 +1,12 @@ from vllm.model_executor.parameter import (BasevLLMParameter, PackedvLLMParameter) -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.sampling_metadata import (SamplingMetadata, + SamplingMetadataCache) from vllm.model_executor.utils import set_random_seed __all__ = [ "SamplingMetadata", + "SamplingMetadataCache", "set_random_seed", "BasevLLMParameter", "PackedvLLMParameter", diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 015e85b4ca81d..94b4b14416821 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -8,8 +8,9 @@ import torch from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData, SequenceGroupMetadata from vllm.triton_utils.sample import get_num_triton_sampler_splits -from vllm.utils import (async_tensor_h2d, is_pin_memory_available, - make_tensor_with_pad, maybe_expand_dim) +from vllm.utils import (PyObjectCache, async_tensor_h2d, + is_pin_memory_available, make_tensor_with_pad, + maybe_expand_dim) _SAMPLING_EPS = 1e-5 _SEED_0_REPLACEMENT = 3403598558 @@ -62,6 +63,39 @@ class SequenceGroupToSample: assert self.query_len is not None +def gen_seq_group_to_sample_builder(num_seqs: int): + return lambda: SequenceGroupToSample( + seq_ids=[0] * num_seqs, + sampling_params=None, + seq_data=None, # type: ignore + seq_len=0, + query_len=0, + generator=None, + is_prompt=True, + prompt_logprob_indices=[], + sample_indices=[]) + + +class SamplingMetadataCache: + """Used to cache SamplingMetadata objects between scheduler iterations + """ + + def __init__(self): + self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {} + + def get_cached_seq_group_to_sample(self, num_seqs): + if num_seqs not in self._seq_group_to_sample_cache: + self._seq_group_to_sample_cache[num_seqs] = PyObjectCache( + gen_seq_group_to_sample_builder(num_seqs)) + + obj = self._seq_group_to_sample_cache[num_seqs].get_object() + return obj + + def reset(self): + for cache in self._seq_group_to_sample_cache.values(): + cache.reset() + + class SamplingMetadata: """Metadata for input sequences. Used in sampler. 
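`SamplingMetadataCache` above keeps one `PyObjectCache` per sequence-group size, so a pooled `SequenceGroupToSample` always comes back with a `seq_ids` list of the right length that can simply be overwritten in place. A sketch of the intended lifecycle, assuming the classes defined above are in scope:

    cache = SamplingMetadataCache()

    for step in range(2):                # two scheduler iterations
        # A group with 3 sequences gets an object whose seq_ids has length 3.
        obj = cache.get_cached_seq_group_to_sample(3)
        for j, seq_id in enumerate([10 + step, 11 + step, 12 + step]):
            obj.seq_ids[j] = seq_id      # overwrite slots instead of rebuilding
        obj.prompt_logprob_indices.clear()
        obj.sample_indices.clear()
        # ... fill the remaining fields as _prepare_seq_groups does ...
        cache.reset()                    # make pooled objects reusable next step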
@@ -121,6 +155,7 @@ class SamplingMetadata: device: str, pin_memory: bool, generators: Optional[Dict[str, torch.Generator]] = None, + cache: Optional[SamplingMetadataCache] = None, ) -> "SamplingMetadata": ( seq_groups, @@ -128,7 +163,7 @@ class SamplingMetadata: categorized_sample_indices, num_prompts, ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, - device, generators) + device, generators, cache) selected_token_indices = async_tensor_h2d(selected_token_indices, dtype=torch.long, target_device=device, @@ -164,6 +199,7 @@ def _prepare_seq_groups( query_lens: Optional[List[int]], device: str, generators: Optional[Dict[str, torch.Generator]] = None, + cache: Optional[SamplingMetadataCache] = None, ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ SamplingType, List[Tuple[int, int]]], int]: """Prepare sequence groups and indices for sampling. @@ -210,15 +246,27 @@ def _prepare_seq_groups( num_prompts = 0 for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = list(seq_group_metadata.seq_data.keys()) + seq_ids = seq_group_metadata.seq_data.keys() + + if cache is not None: + sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids)) + + for j, seq_id in enumerate(seq_ids): + sample_obj.seq_ids[j] = seq_id + + sample_obj.prompt_logprob_indices.clear() + sample_obj.sample_indices.clear() + sampling_params = seq_group_metadata.sampling_params is_prompt = seq_group_metadata.is_prompt generator: Optional[torch.Generator] = None # If the current seq group is in decode stage, it is None. seq_len: Optional[int] = None query_len: Optional[int] = None - prompt_logprob_indices: List[int] = [] - sample_indices: List[int] = [] + prompt_logprob_indices: List[int] = \ + sample_obj.prompt_logprob_indices if cache is not None else [] + sample_indices: List[int] = \ + sample_obj.sample_indices if cache is not None else [] do_sample = seq_group_metadata.do_sample if seq_group_metadata.is_prompt: @@ -290,9 +338,16 @@ def _prepare_seq_groups( logit_idx += sample_len sample_idx += sample_len - seq_groups.append( - SequenceGroupToSample( - seq_ids=seq_ids, + if cache is not None: + sample_obj.sampling_params = sampling_params + sample_obj.seq_data = seq_group_metadata.seq_data + sample_obj.seq_len = seq_len + sample_obj.query_len = query_len + sample_obj.generator = generator + sample_obj.is_prompt = is_prompt + else: + sample_obj = SequenceGroupToSample( + seq_ids=list(seq_ids), sampling_params=sampling_params, seq_data=seq_group_metadata.seq_data, seq_len=seq_len, @@ -300,7 +355,13 @@ def _prepare_seq_groups( generator=generator, is_prompt=is_prompt, prompt_logprob_indices=list(prompt_logprob_indices), - sample_indices=list(sample_indices))) + sample_indices=list(sample_indices)) + + seq_groups.append(sample_obj) + + if cache is not None: + cache.reset() + return (seq_groups, selected_token_indices, categorized_sample_indices, num_prompts) diff --git a/vllm/outputs.py b/vllm/outputs.py index 040f770814576..6e11ff841c62e 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -139,7 +139,7 @@ class RequestOutput: CompletionOutput( seqs.index(seq), seq.get_output_text_to_return(text_buffer_length), - seq.get_output_token_ids(), + seq.data._output_token_ids, # type: ignore seq.get_cumulative_logprob() if include_logprobs else None, seq.output_logprobs if include_logprobs else None, SequenceStatus.get_finished_reason(seq.status), diff --git a/vllm/sequence.py b/vllm/sequence.py index fbd148001cc7e..ba477efc54dd6 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py 
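With the new optional `cache` argument on `SamplingMetadata.prepare`, a caller can hold a single `SamplingMetadataCache` for its lifetime and pass it on every step; `_prepare_seq_groups` resets it once the groups have been built. A hedged sketch of the call-site pattern (the argument names here are placeholders for whatever the caller already has):

    # Held for the lifetime of the runner, as model_runner.py does further below.
    sampling_metadata_cache = SamplingMetadataCache()

    def prepare_for_step(seq_group_metadata_list, seq_lens, query_lens, device,
                         pin_memory, generators):
        # Passing cache= makes _prepare_seq_groups reuse pooled
        # SequenceGroupToSample objects instead of allocating new ones.
        return SamplingMetadata.prepare(seq_group_metadata_list,
                                        seq_lens,
                                        query_lens,
                                        device,
                                        pin_memory,
                                        generators,
                                        cache=sampling_metadata_cache)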
@@ -1,7 +1,6 @@ """Sequence and its related classes.""" import copy import enum -import math from abc import ABC, abstractmethod from array import array from collections import defaultdict @@ -330,7 +329,7 @@ class Sequence: @property def n_blocks(self) -> int: - return math.ceil(self.get_len() / self.block_size) + return (self.get_len() + self.block_size - 1) // self.block_size @property def prompt(self) -> Optional[str]: @@ -514,7 +513,9 @@ class SequenceGroup: ) -> None: self.request_id = request_id self.seqs = seqs + self.is_single_seq = len(seqs) == 1 self.seqs_dict = {seq.seq_id: seq for seq in seqs} + self.sampling_params = sampling_params self.metrics = RequestMetrics(arrival_time=arrival_time, last_token_time=arrival_time, @@ -635,6 +636,10 @@ class SequenceGroup: ) -> List[Sequence]: if status is None: return self.seqs + + if self.is_single_seq: + return self.seqs if self.seqs[0].status == status else [] + return [seq for seq in self.seqs if seq.status == status] def is_encoder_decoder(self) -> bool: @@ -644,6 +649,9 @@ class SequenceGroup: return self.encoder_seq def get_unfinished_seqs(self) -> List[Sequence]: + if self.is_single_seq: + return self.seqs if not self.seqs[0].is_finished() else [] + return [seq for seq in self.seqs if not seq.is_finished()] def get_finished_seqs(self) -> List[Sequence]: @@ -668,12 +676,21 @@ class SequenceGroup: if status is None: return len(self.seqs) + if self.is_single_seq: + return 1 if self.seqs[0].status == status else 0 + return len(self.get_seqs(status)) def num_unfinished_seqs(self) -> int: + if self.is_single_seq: + return 1 if not self.seqs[0].is_finished() else 0 + return len(self.get_unfinished_seqs()) def num_finished_seqs(self) -> int: + if self.is_single_seq: + return 1 if self.seqs[0].is_finished() else 0 + return len(self.get_finished_seqs()) def find(self, seq_id: int) -> Sequence: @@ -686,12 +703,14 @@ class SequenceGroup: raise ValueError(f"Sequence {seq.seq_id} already exists.") self.seqs_dict[seq.seq_id] = seq self.seqs.append(seq) + self.is_single_seq = len(self.seqs) == 1 def remove(self, seq_id: int) -> None: seq = self.seqs_dict.pop(seq_id, None) if seq is None: raise ValueError(f"Sequence {seq_id} not found.") self.seqs.remove(seq) + self.is_single_seq = len(self.seqs) == 1 def is_finished(self) -> bool: return all(seq.is_finished() for seq in self.seqs) @@ -775,9 +794,10 @@ class SequenceGroupMetadata: # TODO: We should maintain this states out of the sequence group. self.num_speculative_tokens = None - if self._token_chunk_size is None: + if seq_data is not None and self._token_chunk_size is None: if is_prompt: - self._token_chunk_size = list(seq_data.values())[0].get_len() + self._token_chunk_size = next(iter( + seq_data.values())).get_len() else: self._token_chunk_size = 1 diff --git a/vllm/utils.py b/vllm/utils.py index f8251284af4aa..ecdd4760ee71f 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -261,6 +261,44 @@ class LRUCache(Generic[T]): self.cache.clear() +class PyObjectCache: + """Used to cache python objects to avoid object allocations + across scheduler iterations. + """ + + def __init__(self, obj_builder): + self._obj_builder = obj_builder + self._index = 0 + + self._obj_cache = [] + for _ in range(128): + self._obj_cache.append(self._obj_builder()) + + def _grow_cache(self): + # Double the size of the cache + num_objs = len(self._obj_cache) + for _ in range(num_objs): + self._obj_cache.append(self._obj_builder()) + + def get_object(self): + """Returns a pre-allocated cached object. 
If there is not enough + objects, then the cache size will double. + """ + if self._index >= len(self._obj_cache): + self._grow_cache() + assert self._index < len(self._obj_cache) + + obj = self._obj_cache[self._index] + self._index += 1 + + return obj + + def reset(self): + """Makes all cached-objects available for the next scheduler iteration. + """ + self._index = 0 + + def is_hip() -> bool: return torch.version.hip is not None diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 8b744a438e81a..913a08ce9f53d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,5 +1,6 @@ import dataclasses import gc +import itertools import time import warnings import weakref @@ -35,7 +36,7 @@ from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager -from vllm.model_executor import SamplingMetadata +from vllm.model_executor import SamplingMetadata, SamplingMetadataCache from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models.interfaces import (supports_lora, @@ -50,8 +51,8 @@ from vllm.prompt_adapter.worker_manager import ( from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) -from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, flatten_2d_lists, - get_kv_cache_torch_dtype, is_hip, +from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, + flatten_2d_lists, get_kv_cache_torch_dtype, is_hip, is_pin_memory_available) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, @@ -178,6 +179,20 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): class InterDataForSeqGroup: """Intermediate data for the current sequence group.""" + def simple_reinit(self): + self.input_tokens[0].clear() # type: ignore + self.input_positions[0].clear() # type: ignore + self.seq_lens[0] = 0 # type: ignore + self.orig_seq_lens[0] = 0 # type: ignore + self.query_lens[0] = 0 # type: ignore + self.context_lens[0] = 0 # type: ignore + self.curr_sliding_window_blocks[0] = 0 # type: ignore + self.lora_index_mapping.clear() # type: ignore + self.lora_prompt_mapping.clear() # type: ignore + self.lora_requests.clear() # type: ignore + self.prompt_adapter_index_mapping.clear() # type: ignore + self.prompt_adapter_prompt_mapping.clear() # type: ignore + def __init__( self, *, @@ -220,35 +235,121 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): # Whether the prefix cache is hit (prefill only). 
prefix_cache_hit: bool = False, + reinit: bool = False, + reinit_use_defaults: bool = False, ): + if reinit: + assert len(self.seq_ids) == len(seq_ids) # type: ignore + for i, seq_id in enumerate(seq_ids): + self.seq_ids[i] = seq_id # type: ignore + else: + self.seq_ids = seq_ids + self.request_id = request_id - self.seq_ids = seq_ids self.is_prompt = is_prompt self.block_tables = block_tables self.computed_block_nums = computed_block_nums self.n_seqs = n_seqs - self.input_tokens = input_tokens or [] - self.input_positions = input_positions or [] - self.seq_lens = seq_lens or [] - self.orig_seq_lens = orig_seq_lens or [] - self.query_lens = query_lens or [] - self.context_lens = context_lens or [] - self.curr_sliding_window_blocks = curr_sliding_window_blocks or [] - self.lora_index_mapping = lora_index_mapping or [] - self.lora_prompt_mapping = lora_prompt_mapping or [] - self.lora_requests = lora_requests or set() + if reinit: + if len(self.seq_ids) == 1 and reinit_use_defaults: + self.simple_reinit() + else: + if input_tokens: + self.input_tokens = input_tokens + else: + for seq_id in range(len(self.seq_ids)): + self.input_tokens[seq_id].clear() + + if input_positions: + self.input_positions = input_positions + else: + for seq_id in range(len(self.seq_ids)): + self.input_positions[seq_id].clear() + + if seq_lens: + self.seq_lens = seq_lens + else: + for seq_id in range(len(self.seq_ids)): + self.seq_lens[seq_id] = 0 + + if orig_seq_lens: + self.orig_seq_lens = orig_seq_lens + else: + for seq_id in range(len(self.seq_ids)): + self.orig_seq_lens[seq_id] = 0 + + if query_lens: + self.query_lens = query_lens + else: + for seq_id in range(len(self.seq_ids)): + self.query_lens[seq_id] = 0 + + if context_lens: + self.context_lens = context_lens + else: + for seq_id in range(len(self.seq_ids)): + self.context_lens[seq_id] = 0 + + if curr_sliding_window_blocks: + self.curr_sliding_window_blocks = \ + curr_sliding_window_blocks + else: + for seq_id in range(len(self.seq_ids)): + self.curr_sliding_window_blocks[seq_id] = 0 + + if lora_index_mapping: + self.lora_index_mapping = lora_index_mapping + else: + self.lora_index_mapping.clear() + + if lora_prompt_mapping: + self.lora_prompt_mapping = lora_prompt_mapping + else: + self.lora_prompt_mapping.clear() + + if lora_requests: + self.lora_requests = lora_requests + else: + self.lora_requests.clear() + + if prompt_adapter_index_mapping: + self.prompt_adapter_index_mapping = \ + prompt_adapter_index_mapping + else: + self.prompt_adapter_index_mapping.clear() + + if prompt_adapter_prompt_mapping: + self.prompt_adapter_prompt_mapping = \ + prompt_adapter_prompt_mapping + else: + self.prompt_adapter_prompt_mapping.clear() + + else: + self.input_tokens = input_tokens or [] + self.input_positions = input_positions or [] + self.seq_lens = seq_lens or [] + self.orig_seq_lens = orig_seq_lens or [] + self.query_lens = query_lens or [] + self.context_lens = context_lens or [] + self.curr_sliding_window_blocks = \ + curr_sliding_window_blocks or [] + + self.lora_index_mapping = lora_index_mapping or [] + self.lora_prompt_mapping = lora_prompt_mapping or [] + self.lora_requests = lora_requests or set() + + self.prompt_adapter_index_mapping = ( + prompt_adapter_index_mapping or []) + self.prompt_adapter_prompt_mapping = ( + prompt_adapter_prompt_mapping or []) - self.prompt_adapter_index_mapping = (prompt_adapter_index_mapping - or []) - self.prompt_adapter_prompt_mapping = (prompt_adapter_prompt_mapping - or []) self.prompt_adapter_request = 
prompt_adapter_request - self.multi_modal_inputs = multi_modal_inputs self.prefix_cache_hit = prefix_cache_hit - self.__post_init__() + if not reinit: + self.__post_init__() def __post_init__(self): self.n_seqs = len(self.seq_ids) @@ -261,8 +362,36 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): self.context_lens = [0] * self.n_seqs self.curr_sliding_window_blocks = [0] * self.n_seqs - self.lora_index_mapping = [[] for _ in range(self.n_seqs)] - self.lora_prompt_mapping = [[] for _ in range(self.n_seqs)] + self.lora_index_mapping = [] + self.lora_prompt_mapping = [] + + def gen_inter_data_builder(self, num_seqs: int): + return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup( + request_id="", + seq_ids=[0] * num_seqs, + is_prompt=True, + block_tables=None, + computed_block_nums=[]) + + def init_cached_inter_data(self, *args, **kwargs): + assert len(args) == 0 + assert "seq_ids" in kwargs + seq_ids = kwargs["seq_ids"] + num_seqs = len(seq_ids) + + # The inter-data cache is per model_runner + inter_data_cache = self.runner.inter_data_cache + if num_seqs not in inter_data_cache: + inter_data_cache[num_seqs] = PyObjectCache( + self.gen_inter_data_builder(num_seqs)) + + obj = inter_data_cache[num_seqs].get_object() + obj.__init__(*args, **kwargs) + return obj + + def reset_cached_inter_data(self): + for cache in self.runner.inter_data_cache.values(): + cache.reset() def __init__(self, runner: "GPUModelRunnerBase", @@ -337,17 +466,29 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): # Compute tokens. if inter_data.is_prompt: - tokens = seq_data.get_token_ids()[context_len:seq_len] + tokens = seq_data.get_token_ids() + if context_len != 0 or seq_len < len(tokens): + tokens = tokens[context_len:seq_len] else: # Optimization. get_token_ids requires the entire copy of # tokens. 
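The prompt-token handling above only takes the `[context_len:seq_len]` slice when it would actually differ from the full list, because a Python slice always allocates a copy. In miniature:

    token_ids = list(range(8))           # stand-in for seq_data.get_token_ids()
    context_len, seq_len = 0, 8

    if context_len != 0 or seq_len < len(token_ids):
        tokens = token_ids[context_len:seq_len]   # genuine sub-range: copies
    else:
        tokens = token_ids                        # whole prompt: no copy
    assert tokens is token_ids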
- tokens = [seq_data.get_last_token_id()] + tokens = seq_data.get_last_token_id() inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.context_lens[seq_idx] = context_len - inter_data.input_tokens[seq_idx] = tokens - inter_data.input_positions[seq_idx] = list(range(context_len, seq_len)) + + if isinstance(tokens, list): + inter_data.input_tokens[seq_idx].extend(tokens) + else: + inter_data.input_tokens[seq_idx].append(tokens) + + if (seq_len - context_len) == 1: + inter_data.input_positions[seq_idx].append(seq_len - 1) + else: + inter_data.input_positions[seq_idx].extend( + range(context_len, seq_len)) + inter_data.query_lens[ seq_idx] = seq_len - context_len if inter_data.is_prompt else 1 @@ -471,7 +612,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): """Add a sequence group to the builder.""" - seq_ids = list(seq_group_metadata.seq_data.keys()) + seq_ids = seq_group_metadata.seq_data.keys() n_seqs = len(seq_ids) is_prompt = seq_group_metadata.is_prompt @@ -479,12 +620,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): assert n_seqs == 1 self.decode_only = False - inter_data = self.InterDataForSeqGroup( + inter_data = self.init_cached_inter_data( request_id=seq_group_metadata.request_id, seq_ids=seq_ids, is_prompt=is_prompt, block_tables=seq_group_metadata.block_tables, - computed_block_nums=seq_group_metadata.computed_block_nums) + computed_block_nums=seq_group_metadata.computed_block_nums, + reinit=True, + reinit_use_defaults=True) + self.inter_data_list.append(inter_data) for seq_idx in range(n_seqs): @@ -504,18 +648,21 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): create on-device tensors. """ # Combine and flatten intermediate data. - input_tokens = flatten_2d_lists([ - flatten_2d_lists(inter_data.input_tokens) - for inter_data in self.inter_data_list - ]) + input_tokens = [] + for inter_data in self.inter_data_list: + for cur_input_tokens in inter_data.input_tokens: + input_tokens.extend(cur_input_tokens) + if not input_tokens: # This may happen when all prefill requests hit # prefix caching and there is no decode request. return self.model_input_cls() - input_positions = flatten_2d_lists([ - flatten_2d_lists(inter_data.input_positions) - for inter_data in self.inter_data_list - ]) + + input_positions = [] + for inter_data in self.inter_data_list: + for cur_input_positions in inter_data.input_positions: + input_positions.extend(cur_input_positions) + seq_lens = [] max_decode_seq_len = 0 for inter_data in self.inter_data_list: @@ -523,8 +670,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): if not inter_data.is_prompt: max_decode_seq_len = max(max_decode_seq_len, max(inter_data.seq_lens)) - query_lens = flatten_2d_lists( - [inter_data.query_lens for inter_data in self.inter_data_list]) + query_lens = [] + for inter_data in self.inter_data_list: + query_lens.extend(inter_data.query_lens) + # Mapping from request IDs to sequence IDs. Used for Jamba models # that manages the cache by itself. request_ids_to_seq_ids = { @@ -547,8 +696,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): batch_size = graph_batch_size # Tokens and positions. 
- input_tokens.extend([0] * cuda_graph_pad_size) - input_positions.extend([0] * cuda_graph_pad_size) + if cuda_graph_pad_size: + input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size)) + input_positions.extend(itertools.repeat(0, cuda_graph_pad_size)) assert self.runner.device is not None input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, self.runner.device, @@ -558,7 +708,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): self.runner.pin_memory) # Sequence and query lengths. - seq_lens.extend([1] * cuda_graph_pad_size) + if cuda_graph_pad_size: + seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size)) # Attention metadata. attn_metadata = self.attn_metadata_builder.build( @@ -574,11 +725,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): flatten_2d_lists(inter_data.lora_index_mapping) for inter_data in self.inter_data_list ]) - lora_index_mapping.extend([0] * cuda_graph_pad_size) + if cuda_graph_pad_size: + lora_index_mapping.extend( + itertools.repeat(0, cuda_graph_pad_size)) lora_prompt_mapping = flatten_2d_lists([ flatten_2d_lists(inter_data.lora_prompt_mapping) for inter_data in self.inter_data_list ]) + lora_mapping = LoRAMapping( **dict(index_mapping=lora_index_mapping, prompt_mapping=lora_prompt_mapping, @@ -595,7 +749,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): inter_data.prompt_adapter_index_mapping for inter_data in self.inter_data_list ]) - prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size) + if cuda_graph_pad_size: + prompt_adapter_index_mapping.extend( + itertools.repeat(0, cuda_graph_pad_size)) prompt_adapter_prompt_mapping = flatten_2d_lists([ inter_data.prompt_adapter_prompt_mapping for inter_data in self.inter_data_list @@ -717,6 +873,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): set_cpu_offload_max_bytes( int(self.cache_config.cpu_offload_gb * 1024**3)) + # Used to cache python objects + self.inter_data_cache: Dict[int, PyObjectCache] = {} + self.sampling_metadata_cache: SamplingMetadataCache = \ + SamplingMetadataCache() + def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with CudaMemoryProfiler() as m: @@ -843,6 +1004,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) for seq_group_metadata in seq_group_metadata_list: builder.add_seq_group(seq_group_metadata) + + builder.reset_cached_inter_data() + return builder.build() # type: ignore @torch.inference_mode() @@ -1276,7 +1440,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, model_input.seq_lens, model_input.query_lens, self.device, self.pin_memory, - generators) + generators, self.sampling_metadata_cache) else: sampling_metadata = None is_prompt = (seq_group_metadata_list[0].is_prompt
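Taken together, the pools added by this patch share one lifecycle per engine step: objects are handed out while the batch is being assembled, and each cache is reset once the step's inputs exist, so the next iteration reuses the same Python objects instead of reallocating them. A schematic sketch of that flow; the driver function here is hypothetical and stands in for the engine's real step loop:

    def run_step(scheduler, model_runner):
        # 1. Scheduling: SchedulerRunningOutputs / ScheduledSequenceGroup /
        #    SequenceGroupMetadata come from the scheduler's PyObjectCache pools,
        #    which are reset before schedule() returns.
        seq_group_metadata_list, scheduler_outputs = scheduler.schedule()

        # 2. Input building: InterDataForSeqGroup objects come from
        #    model_runner.inter_data_cache (keyed by sequence count) and are
        #    released again via builder.reset_cached_inter_data() after build().
        # 3. Sampling metadata: SamplingMetadata.prepare(...) fills pooled
        #    SequenceGroupToSample objects from model_runner.sampling_metadata_cache.
        model_input = model_runner.prepare_model_input(seq_group_metadata_list)

        return model_input, scheduler_outputs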