mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-08 06:07:02 +08:00

[Performance] Optimize e2e overheads: Reduce python allocations (#7162)

This commit is contained in:
parent 73388c07a4
commit e02ac55617
@@ -259,7 +259,11 @@ class FlashAttentionMetadataBuilder(
                 block_table = block_tables[seq_id]
             elif ((chunked_prefill_enabled or not is_prompt)
                   and block_tables is not None):
-                block_table = block_tables[seq_id][-curr_sliding_window_block:]
+                if curr_sliding_window_block == 0:
+                    block_table = block_tables[seq_id]
+                else:
+                    block_table = block_tables[seq_id][
+                        -curr_sliding_window_block:]
             self.block_tables.append(block_table)

             # Compute slot mapping.
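The hunk above skips list slicing when the sliding window already spans the whole block table, since a slice always allocates a copy. A minimal standalone sketch of the same idea (the function name and signature are illustrative, not vLLM's internal API):

from typing import List

def sliding_window_view(block_table: List[int], window_blocks: int) -> List[int]:
    # Slicing always copies; returning the original list when no trimming is
    # needed (window_blocks == 0) avoids one list allocation per sequence
    # per scheduler step.
    if window_blocks == 0:
        return block_table
    return block_table[-window_blocks:]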
@@ -68,13 +68,21 @@ def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int],
     # tokens are masked and the slot mapping will be
     # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
     block_table = block_tables[seq_id]
-    slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
-    for i in range(max(start_idx, context_len), seq_len):
+
+    def add_slot(i):
         block_number = block_table[i // block_size]
         block_offset = i % block_size
         slot = block_number * block_size + block_offset
         slot_mapping.append(slot)
+
+    if start_idx == 0 and (seq_len - context_len) == 1:
+        # Optimization for common-case of decoding next token
+        add_slot(seq_len - 1)
+    else:
+        slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
+        for i in range(max(start_idx, context_len), seq_len):
+            add_slot(i)


 TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata')
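The rewritten slot-mapping code takes a shortcut when exactly one new token is being decoded, so a single slot is computed instead of building padding and iterating. A self-contained sketch under assumed names (PAD_SLOT_ID and the standalone function are stand-ins for the real module-level helpers):

from typing import List

PAD_SLOT_ID = -1  # stand-in for vLLM's pad value

def compute_slots(block_table: List[int], block_size: int,
                  start_idx: int, context_len: int, seq_len: int) -> List[int]:
    slot_mapping: List[int] = []

    def add_slot(i: int) -> None:
        block_number = block_table[i // block_size]
        slot_mapping.append(block_number * block_size + i % block_size)

    if start_idx == 0 and (seq_len - context_len) == 1:
        # Common decode case: one new token, no padding work needed.
        add_slot(seq_len - 1)
    else:
        slot_mapping.extend([PAD_SLOT_ID] * max(0, start_idx - context_len))
        for i in range(max(start_idx, context_len), seq_len):
            add_slot(i)
    return slot_mapping

# e.g. decoding token index 8 with block_size=4 and block_table=[7, 3, 5]:
# compute_slots([7, 3, 5], 4, 0, 8, 9) == [5 * 4 + 0] == [20]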
@@ -1,5 +1,5 @@
 """Token blocks."""
-from typing import List
+from typing import List, Optional

 from vllm.utils import Device
@@ -37,5 +37,47 @@ class PhysicalTokenBlock:
                 f'computed={self.computed})')


 # Mapping: logical block number -> physical block.
-BlockTable = List[PhysicalTokenBlock]
+class BlockTable:
+    """Holds a list of blocks with caching of their associated block_ids
+    """
+
+    def __init__(self, blocks: Optional[List[PhysicalTokenBlock]] = None):
+        self._blocks: List[PhysicalTokenBlock] = []
+        self._block_ids: List[int] = []
+
+        if blocks is not None:
+            for block in blocks:
+                self.append(block)
+
+    def append(self, block: PhysicalTokenBlock):
+        self._blocks.append(block)
+        self._block_ids.append(block.block_number)
+
+    def __len__(self) -> int:
+        return len(self._blocks)
+
+    def __getitem__(self, key):
+        return self._blocks[key]
+
+    def __setitem__(self, key, value):
+        if isinstance(key, slice):
+            blocks = value
+            self._blocks[key] = blocks
+            self._block_ids[key] = [b.block_number for b in blocks]
+        else:
+            block = value
+            self._blocks[key] = block
+            self._block_ids[key] = block.block_number
+
+    def reset(self):
+        self._blocks = []
+        self._block_ids = []
+
+    def copy(self) -> "BlockTable":
+        return BlockTable(self._blocks)
+
+    def list(self) -> List[PhysicalTokenBlock]:
+        return self._blocks
+
+    def ids(self) -> List[int]:
+        return self._block_ids
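The new BlockTable wrapper keeps a parallel list of block numbers, so callers that only need the ids (the per-step get_block_table call further down) no longer rebuild a list from the block objects on every scheduler iteration. A small usage sketch of that idea with a stand-in block type:

from dataclasses import dataclass
from typing import List

@dataclass
class Block:          # stand-in for PhysicalTokenBlock
    block_number: int

class IdCachingBlockTable:
    """Minimal sketch of the id-caching behind BlockTable."""

    def __init__(self) -> None:
        self._blocks: List[Block] = []
        self._block_ids: List[int] = []

    def append(self, block: Block) -> None:
        self._blocks.append(block)
        self._block_ids.append(block.block_number)

    def ids(self) -> List[int]:
        # O(1): returns the cached list instead of rebuilding
        # [b.block_number for b in self._blocks] on every call.
        return self._block_ids

table = IdCachingBlockTable()
for n in (3, 7, 9):
    table.append(Block(n))
assert table.ids() == [3, 7, 9]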
@@ -170,7 +170,7 @@ class UncachedBlockAllocator(BlockAllocatorBase):
         self.num_blocks = num_blocks

         # Initialize the free blocks.
-        self.free_blocks: BlockTable = []
+        self.free_blocks: List[PhysicalTokenBlock] = []
         for i in range(num_blocks):
             block = PhysicalTokenBlock(device=device,
                                        block_number=i,
@@ -256,6 +256,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             Device.CPU, block_size, num_cpu_blocks)
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}
+
         # Mapping: req_id -> BlockTable
         # Note that each SequenceGroup has a unique
         # request ID
@@ -299,7 +300,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         # Allocate new physical token blocks that will store the prompt tokens.
         num_prompt_blocks = seq.n_blocks

-        block_table: BlockTable = []
+        block_table: BlockTable = BlockTable()
         for logical_idx in range(num_prompt_blocks):
             if (self.block_sliding_window is not None
                     and logical_idx >= self.block_sliding_window):
@@ -326,15 +327,19 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         #
         # NOTE: Here we assume that all sequences in the group have the same
         # decoder prompt.
-        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
+        wait_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
+        seq = wait_seqs[0]
         block_table: BlockTable = \
             self._allocate_sequence(seq,
                                     seq_group.num_seqs(),
                                     is_encoder_decoder)

         # Assign the self-attention block tables for each sequence.
-        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
-            self.block_tables[seq.seq_id] = block_table.copy()
+        if len(wait_seqs) == 1:
+            self.block_tables[wait_seqs[0].seq_id] = block_table
+        else:
+            for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
+                self.block_tables[seq.seq_id] = block_table.copy()

         # Allocate encoder sequence
         if is_encoder_decoder:
@@ -476,6 +481,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             return
         src_block_table = self.block_tables[parent_seq.seq_id]
         self.block_tables[child_seq.seq_id] = src_block_table.copy()
+
         # When using a sliding window, blocks will be eventually reused.
         # In this case the block tables will contain repeated blocks.
         # When forking, we must make sure that each block's `ref_count`
@@ -527,7 +533,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
                           dest_allocator: BlockAllocatorBase,
                           mapping: Dict[PhysicalTokenBlock,
                                         PhysicalTokenBlock]) -> BlockTable:
-        new_block_table = []
+        new_block_table: BlockTable = BlockTable()

         for from_block in block_table:
             if from_block in mapping:
@@ -553,8 +559,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             self.block_tables[seq.seq_id] = \
                 self._swap_block_table(self.block_tables[seq.seq_id],
-                                       self.cpu_allocator,
-                                       self.gpu_allocator,
+                                       self.cpu_allocator, self.gpu_allocator,
                                        mapping)

         if seq_group.is_encoder_decoder():
@@ -580,8 +585,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             self.block_tables[seq.seq_id] = \
                 self._swap_block_table(self.block_tables[seq.seq_id],
-                                       self.gpu_allocator,
-                                       self.cpu_allocator,
+                                       self.gpu_allocator, self.cpu_allocator,
                                        mapping)

         if seq_group.is_encoder_decoder():
@@ -636,8 +640,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         self.cross_block_tables.clear()

     def get_block_table(self, seq: Sequence) -> List[int]:
-        block_table = self.block_tables[seq.seq_id]
-        return [block.block_number for block in block_table]
+        return self.block_tables[seq.seq_id].ids()

     def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]:
         block_table = self.cross_block_tables[seq_group.request_id]
@@ -13,6 +13,7 @@ from vllm.lora.request import LoRARequest
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                            SequenceGroupMetadata, SequenceStatus)
+from vllm.utils import PyObjectCache

 logger = init_logger(__name__)
@@ -176,10 +177,10 @@ class SchedulerRunningOutputs:
     enough memory, it can be preempted (for recompute) or swapped out.
     """
     # Selected sequences that are running and in a decoding phase.
-    decode_seq_groups: List[SequenceGroup]
+    decode_seq_groups: List[ScheduledSequenceGroup]
     # Selected sequences that are running and in a prefill phase.
     # I.e., it means the prefill has been chunked.
-    prefill_seq_groups: List[SequenceGroup]
+    prefill_seq_groups: List[ScheduledSequenceGroup]
     # The preempted sequences.
     preempted: List[SequenceGroup]
     # Sequences that are swapped out.
@@ -191,6 +192,10 @@ class SchedulerRunningOutputs:
     # The number of slots for lookahead decoding.
     num_lookahead_slots: int

+    # Optimization for fast-access to seq_group lists
+    decode_seq_groups_list: List[SequenceGroup]
+    prefill_seq_groups_list: List[SequenceGroup]
+
     @classmethod
     def create_empty(cls) -> "SchedulerRunningOutputs":
         return SchedulerRunningOutputs(
@@ -201,6 +206,8 @@ class SchedulerRunningOutputs:
             blocks_to_swap_out=[],
             blocks_to_copy=[],
             num_lookahead_slots=0,
+            decode_seq_groups_list=[],
+            prefill_seq_groups_list=[],
         )
@@ -259,6 +266,30 @@ class SchedulerPrefillOutputs:
         )


+def seq_group_metadata_builder():
+    return SequenceGroupMetadata(request_id="",
+                                 is_prompt=False,
+                                 seq_data={},
+                                 sampling_params=None,
+                                 block_tables={})
+
+
+def scheduler_running_outputs_builder():
+    return SchedulerRunningOutputs(decode_seq_groups=[],
+                                   prefill_seq_groups=[],
+                                   preempted=[],
+                                   swapped_out=[],
+                                   blocks_to_swap_out=[],
+                                   blocks_to_copy=[],
+                                   num_lookahead_slots=0,
+                                   prefill_seq_groups_list=[],
+                                   decode_seq_groups_list=[])
+
+
+def scheduled_seq_group_builder():
+    return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
+
+
 class Scheduler:

     def __init__(
@@ -331,6 +362,14 @@ class Scheduler:
                                               else 0)
         self.num_cumulative_preemption: int = 0

+        # Used to cache python objects
+        self._seq_group_metadata_cache: PyObjectCache = PyObjectCache(
+            seq_group_metadata_builder)
+        self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache(
+            scheduler_running_outputs_builder)
+        self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache(
+            scheduled_seq_group_builder)
+
     @property
     def lora_enabled(self) -> bool:
         return bool(self.lora_config)
@@ -441,14 +480,30 @@ class Scheduler:
         Returns:
             SchedulerRunningOutputs.
         """
-        # Blocks that need to be swapped or copied before model execution.
-        blocks_to_swap_out: List[Tuple[int, int]] = []
-        blocks_to_copy: List[Tuple[int, int]] = []
+        ret: SchedulerRunningOutputs = \
+            self._scheduler_running_outputs_cache.get_object()
+        ret.blocks_to_swap_out.clear()
+        ret.blocks_to_copy.clear()
+        ret.decode_seq_groups.clear()
+        ret.prefill_seq_groups.clear()
+        ret.preempted.clear()
+        ret.swapped_out.clear()

-        decode_seq_groups: List[ScheduledSequenceGroup] = []
-        prefill_seq_groups: List[ScheduledSequenceGroup] = []
-        preempted: List[SequenceGroup] = []
-        swapped_out: List[SequenceGroup] = []
+        ret.num_lookahead_slots = self._get_num_lookahead_slots(
+            is_prefill=False)
+
+        ret.decode_seq_groups_list.clear()
+        ret.prefill_seq_groups_list.clear()
+
+        # Blocks that need to be swapped or copied before model execution.
+        blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out
+        blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy
+
+        decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups
+        prefill_seq_groups: List[
+            ScheduledSequenceGroup] = ret.prefill_seq_groups
+        preempted: List[SequenceGroup] = ret.preempted
+        swapped_out: List[SequenceGroup] = ret.swapped_out

         # NOTE(woosuk): Preemption happens only when there is no available slot
         # to keep all the sequence groups in the RUNNING state.
@@ -497,15 +552,19 @@ class Scheduler:
             else:
                 self._append_slots(seq_group, blocks_to_copy)
                 is_prefill = seq_group.is_prefill()
+
+                scheduled_seq_group: ScheduledSequenceGroup = \
+                    self._scheduled_seq_group_cache.get_object()
+                scheduled_seq_group.seq_group = seq_group
                 if is_prefill:
-                    prefill_seq_groups.append(
-                        ScheduledSequenceGroup(
-                            seq_group=seq_group,
-                            token_chunk_size=num_running_tokens))
+                    scheduled_seq_group.token_chunk_size = num_running_tokens
+                    prefill_seq_groups.append(scheduled_seq_group)
+                    ret.prefill_seq_groups_list.append(seq_group)
                 else:
-                    decode_seq_groups.append(
-                        ScheduledSequenceGroup(seq_group=seq_group,
-                                               token_chunk_size=1))
+                    scheduled_seq_group.token_chunk_size = 1
+                    decode_seq_groups.append(scheduled_seq_group)
+                    ret.decode_seq_groups_list.append(seq_group)

                 budget.add_num_batched_tokens(seq_group.request_id,
                                               num_running_tokens)
                 # OPTIMIZATION: Note that get_max_num_running_seqs is
@@ -518,15 +577,10 @@ class Scheduler:
                 if curr_loras is not None and seq_group.lora_int_id > 0:
                     curr_loras.add(seq_group.lora_int_id)

-        return SchedulerRunningOutputs(
-            decode_seq_groups=decode_seq_groups,
-            prefill_seq_groups=prefill_seq_groups,
-            preempted=preempted,
-            swapped_out=swapped_out,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-            num_lookahead_slots=self._get_num_lookahead_slots(
-                is_prefill=False))
+        self._scheduler_running_outputs_cache.reset()
+        self._scheduled_seq_group_cache.reset()
+
+        return ret

     def _schedule_swapped(
         self,
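Note that the locals assigned from the cached object (blocks_to_swap_out, decode_seq_groups, and so on) are aliases, not copies, so everything appended downstream lands directly in the pooled SchedulerRunningOutputs that is returned. A two-line illustration of that aliasing, using a dict as a stand-in for the pooled object:

ret = {"blocks_to_copy": []}             # stand-in for the cached outputs object
blocks_to_copy = ret["blocks_to_copy"]   # alias to the same list, not a copy

blocks_to_copy.append((3, 4))
assert ret["blocks_to_copy"] == [(3, 4)]  # the pooled object sees the append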
@@ -820,11 +874,15 @@ class Scheduler:
         # Update waiting requests.
         self.waiting.extendleft(running_scheduled.preempted)
         # Update new running requests.
-        self.running.extend([s.seq_group for s in prefills.seq_groups])
-        self.running.extend(
-            [s.seq_group for s in running_scheduled.decode_seq_groups])
-        self.running.extend(
-            [s.seq_group for s in swapped_in.decode_seq_groups])
+        if len(prefills.seq_groups) > 0:
+            self.running.extend([s.seq_group for s in prefills.seq_groups])
+
+        self.running.extend(running_scheduled.decode_seq_groups_list)
+
+        if len(swapped_in.decode_seq_groups) > 0:
+            self.running.extend(
+                [s.seq_group for s in swapped_in.decode_seq_groups])
+
         # Update swapped requests.
         self.swapped.extend(running_scheduled.swapped_out)
         preempted = (len(running_scheduled.preempted) +
@@ -834,18 +892,30 @@ class Scheduler:
         # doesn't allow chunked prefills.
         assert len(running_scheduled.prefill_seq_groups) == 0
         assert len(swapped_in.prefill_seq_groups) == 0

+        # Merge lists
+        num_prefill_groups = len(prefills.seq_groups)
+        if num_prefill_groups > 0:
+            scheduled_seq_groups = prefills.seq_groups
+            scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
+        else:
+            scheduled_seq_groups = running_scheduled.decode_seq_groups
+        scheduled_seq_groups.extend(swapped_in.decode_seq_groups)
+
+        blocks_to_copy = running_scheduled.blocks_to_copy
+        blocks_to_copy.extend(swapped_in.blocks_to_copy)
+
+        ignored_seq_groups = prefills.ignored_seq_groups
+        ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)
+
         return SchedulerOutputs(
-            scheduled_seq_groups=(prefills.seq_groups +
-                                  running_scheduled.decode_seq_groups +
-                                  swapped_in.decode_seq_groups),
-            num_prefill_groups=len(prefills.seq_groups),
+            scheduled_seq_groups=scheduled_seq_groups,
+            num_prefill_groups=num_prefill_groups,
             num_batched_tokens=budget.num_batched_tokens,
             blocks_to_swap_in=swapped_in.blocks_to_swap_in,
             blocks_to_swap_out=running_scheduled.blocks_to_swap_out,
-            blocks_to_copy=running_scheduled.blocks_to_copy +
-            swapped_in.blocks_to_copy,
-            ignored_seq_groups=prefills.ignored_seq_groups +
-            swapped_in.infeasible_seq_groups,
+            blocks_to_copy=blocks_to_copy,
+            ignored_seq_groups=ignored_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
             running_queue_size=len(self.running),
             preempted=preempted,
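The merge above replaces list concatenation with in-place extension: `a + b` always allocates a brand-new list, while extending one of the existing lists reuses it. A tiny illustration of the equivalence:

prefills = ["p0", "p1"]
decodes = ["d0"]
swapped = ["s0"]

# Concatenation: temporary lists are created before the result exists.
merged_copy = prefills + decodes + swapped

# In-place merge: the first list is reused and grown.
merged = prefills            # no copy; prefills becomes the result list
merged.extend(decodes)
merged.extend(swapped)
assert merged == merged_copy == ["p0", "p1", "d0", "s0"]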
@@ -963,6 +1033,9 @@ class Scheduler:
         scheduler_outputs = self._schedule()
         now = time.time()

+        if not self.cache_config.enable_prefix_caching:
+            common_computed_block_nums = []
+
         # Create input data structures.
         seq_group_metadata_list: List[SequenceGroupMetadata] = []
         for i, scheduled_seq_group in enumerate(
@@ -971,10 +1044,15 @@ class Scheduler:
             token_chunk_size = scheduled_seq_group.token_chunk_size
             seq_group.maybe_set_first_scheduled_time(now)

+            seq_group_metadata = self._seq_group_metadata_cache.get_object()
+            seq_group_metadata.seq_data.clear()
+            seq_group_metadata.block_tables.clear()
+
             # seq_id -> SequenceData
-            seq_data: Dict[int, SequenceData] = {}
+            seq_data: Dict[int, SequenceData] = seq_group_metadata.seq_data
             # seq_id -> physical block numbers
-            block_tables: Dict[int, List[int]] = {}
+            block_tables: Dict[int,
+                               List[int]] = seq_group_metadata.block_tables

             if seq_group.is_encoder_decoder():
                 # Encoder associated with SequenceGroup
@@ -993,9 +1071,10 @@ class Scheduler:
                 block_tables[seq_id] = self.block_manager.get_block_table(seq)
                 self.block_manager.access_all_blocks_in_seq(seq, now)

-            common_computed_block_nums = (
-                self.block_manager.get_common_computed_block_ids(
-                    seq_group.get_seqs(status=SequenceStatus.RUNNING)))
+            if self.cache_config.enable_prefix_caching:
+                common_computed_block_nums = (
+                    self.block_manager.get_common_computed_block_ids(
+                        seq_group.get_seqs(status=SequenceStatus.RUNNING)))

             do_sample = True
             if seq_group.is_prefill():
@@ -1014,7 +1093,8 @@ class Scheduler:
             # It assumes the scheduled_seq_groups is ordered by
             # prefill < decoding.
             is_prompt = seq_group.is_prefill()
-            seq_group_metadata = SequenceGroupMetadata(
+
+            seq_group_metadata.__init__(
                 request_id=seq_group.request_id,
                 is_prompt=is_prompt,
                 seq_data=seq_data,
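Calling __init__ on an already-constructed object re-runs the initializer without allocating a new instance; that is how the cached SequenceGroupMetadata is refilled here. A small illustration of the pattern, and of its main caveat (every stale field must be overwritten or explicitly cleared, which is why seq_data and block_tables are cleared above):

class Point:
    def __init__(self, x: int, y: int) -> None:
        self.x = x
        self.y = y

p = Point(1, 2)
same = p
p.__init__(5, 6)        # reuses the existing instance, no new allocation
assert same is p and (p.x, p.y) == (5, 6)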
@@ -1045,6 +1125,8 @@ class Scheduler:
             self.block_manager.mark_blocks_as_computed(
                 scheduled_seq_group.seq_group)

+        self._seq_group_metadata_cache.reset()
+
         return seq_group_metadata_list, scheduler_outputs

     def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
@@ -1093,7 +1175,8 @@ class Scheduler:

         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             cows = self.block_manager.append_slots(seq, num_lookahead_slots)
-            blocks_to_copy.extend(cows)
+            if len(cows) > 0:
+                blocks_to_copy.extend(cows)

     def _preempt(
         self,
@@ -1,10 +1,12 @@
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            PackedvLLMParameter)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.sampling_metadata import (SamplingMetadata,
+                                                   SamplingMetadataCache)
 from vllm.model_executor.utils import set_random_seed

 __all__ = [
     "SamplingMetadata",
+    "SamplingMetadataCache",
     "set_random_seed",
     "BasevLLMParameter",
     "PackedvLLMParameter",
@@ -8,8 +8,9 @@ import torch
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData, SequenceGroupMetadata
 from vllm.triton_utils.sample import get_num_triton_sampler_splits
-from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
-                        make_tensor_with_pad, maybe_expand_dim)
+from vllm.utils import (PyObjectCache, async_tensor_h2d,
+                        is_pin_memory_available, make_tensor_with_pad,
+                        maybe_expand_dim)

 _SAMPLING_EPS = 1e-5
 _SEED_0_REPLACEMENT = 3403598558
@@ -62,6 +63,39 @@ class SequenceGroupToSample:
             assert self.query_len is not None


+def gen_seq_group_to_sample_builder(num_seqs: int):
+    return lambda: SequenceGroupToSample(
+        seq_ids=[0] * num_seqs,
+        sampling_params=None,
+        seq_data=None,  # type: ignore
+        seq_len=0,
+        query_len=0,
+        generator=None,
+        is_prompt=True,
+        prompt_logprob_indices=[],
+        sample_indices=[])
+
+
+class SamplingMetadataCache:
+    """Used to cache SamplingMetadata objects between scheduler iterations
+    """
+
+    def __init__(self):
+        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}
+
+    def get_cached_seq_group_to_sample(self, num_seqs):
+        if num_seqs not in self._seq_group_to_sample_cache:
+            self._seq_group_to_sample_cache[num_seqs] = PyObjectCache(
+                gen_seq_group_to_sample_builder(num_seqs))
+
+        obj = self._seq_group_to_sample_cache[num_seqs].get_object()
+        return obj
+
+    def reset(self):
+        for cache in self._seq_group_to_sample_cache.values():
+            cache.reset()
+
+
 class SamplingMetadata:
     """Metadata for input sequences. Used in sampler.
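SamplingMetadataCache keeps one PyObjectCache per sequence count because the pooled SequenceGroupToSample objects carry a fixed-length seq_ids list that is overwritten index by index; mixing group sizes in one pool would force that list to be reallocated. A hedged sketch of the per-size keying (the class and names below are illustrative, not vLLM's API):

from typing import Dict, List

class SizedPool:
    """Illustrative stand-in: one pool of reusable lists per size."""

    def __init__(self) -> None:
        self._pools: Dict[int, List[List[int]]] = {}

    def get(self, num_seqs: int) -> List[int]:
        pool = self._pools.setdefault(num_seqs, [])
        # Reuse a list of exactly this length, or build one the first time.
        return pool.pop() if pool else [0] * num_seqs

    def put_back(self, seq_ids: List[int]) -> None:
        self._pools.setdefault(len(seq_ids), []).append(seq_ids)

pools = SizedPool()
ids = pools.get(3)                 # [0, 0, 0]; slots are overwritten, never resized
ids[0], ids[1], ids[2] = 11, 12, 13
pools.put_back(ids)
assert pools.get(3) is ids         # the same object comes back for 3-sequence groups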
@@ -121,6 +155,7 @@ class SamplingMetadata:
         device: str,
         pin_memory: bool,
         generators: Optional[Dict[str, torch.Generator]] = None,
+        cache: Optional[SamplingMetadataCache] = None,
     ) -> "SamplingMetadata":
         (
             seq_groups,
@@ -128,7 +163,7 @@ class SamplingMetadata:
             categorized_sample_indices,
             num_prompts,
         ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens,
-                                device, generators)
+                                device, generators, cache)
         selected_token_indices = async_tensor_h2d(selected_token_indices,
                                                   dtype=torch.long,
                                                   target_device=device,
@@ -164,6 +199,7 @@ def _prepare_seq_groups(
     query_lens: Optional[List[int]],
     device: str,
     generators: Optional[Dict[str, torch.Generator]] = None,
+    cache: Optional[SamplingMetadataCache] = None,
 ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[
         SamplingType, List[Tuple[int, int]]], int]:
     """Prepare sequence groups and indices for sampling.
@@ -210,15 +246,27 @@ def _prepare_seq_groups(
     num_prompts = 0

     for i, seq_group_metadata in enumerate(seq_group_metadata_list):
-        seq_ids = list(seq_group_metadata.seq_data.keys())
+        seq_ids = seq_group_metadata.seq_data.keys()
+
+        if cache is not None:
+            sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))
+
+            for j, seq_id in enumerate(seq_ids):
+                sample_obj.seq_ids[j] = seq_id
+
+            sample_obj.prompt_logprob_indices.clear()
+            sample_obj.sample_indices.clear()
+
         sampling_params = seq_group_metadata.sampling_params
         is_prompt = seq_group_metadata.is_prompt
         generator: Optional[torch.Generator] = None
         # If the current seq group is in decode stage, it is None.
         seq_len: Optional[int] = None
         query_len: Optional[int] = None
-        prompt_logprob_indices: List[int] = []
-        sample_indices: List[int] = []
+        prompt_logprob_indices: List[int] = \
+            sample_obj.prompt_logprob_indices if cache is not None else []
+        sample_indices: List[int] = \
+            sample_obj.sample_indices if cache is not None else []
         do_sample = seq_group_metadata.do_sample

         if seq_group_metadata.is_prompt:
@@ -290,9 +338,16 @@ def _prepare_seq_groups(
             logit_idx += sample_len
             sample_idx += sample_len

-        seq_groups.append(
-            SequenceGroupToSample(
-                seq_ids=seq_ids,
+        if cache is not None:
+            sample_obj.sampling_params = sampling_params
+            sample_obj.seq_data = seq_group_metadata.seq_data
+            sample_obj.seq_len = seq_len
+            sample_obj.query_len = query_len
+            sample_obj.generator = generator
+            sample_obj.is_prompt = is_prompt
+        else:
+            sample_obj = SequenceGroupToSample(
+                seq_ids=list(seq_ids),
                 sampling_params=sampling_params,
                 seq_data=seq_group_metadata.seq_data,
                 seq_len=seq_len,
@@ -300,7 +355,13 @@ def _prepare_seq_groups(
                 generator=generator,
                 is_prompt=is_prompt,
                 prompt_logprob_indices=list(prompt_logprob_indices),
-                sample_indices=list(sample_indices)))
+                sample_indices=list(sample_indices))
+
+        seq_groups.append(sample_obj)
+
+    if cache is not None:
+        cache.reset()

     return (seq_groups, selected_token_indices, categorized_sample_indices,
             num_prompts)
@@ -139,7 +139,7 @@ class RequestOutput:
                 CompletionOutput(
                     seqs.index(seq),
                     seq.get_output_text_to_return(text_buffer_length),
-                    seq.get_output_token_ids(),
+                    seq.data._output_token_ids,  # type: ignore
                     seq.get_cumulative_logprob() if include_logprobs else None,
                     seq.output_logprobs if include_logprobs else None,
                     SequenceStatus.get_finished_reason(seq.status),
@@ -1,7 +1,6 @@
 """Sequence and its related classes."""
 import copy
 import enum
-import math
 from abc import ABC, abstractmethod
 from array import array
 from collections import defaultdict
@@ -330,7 +329,7 @@ class Sequence:

     @property
     def n_blocks(self) -> int:
-        return math.ceil(self.get_len() / self.block_size)
+        return (self.get_len() + self.block_size - 1) // self.block_size

     @property
     def prompt(self) -> Optional[str]:
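`(n + b - 1) // b` is the integer form of ceil(n / b) for positive b; it gives the same block count as math.ceil without the float round-trip or the function call on this frequently read property. For example:

import math

block_size = 16
for n in (1, 15, 16, 17, 4096):
    assert (n + block_size - 1) // block_size == math.ceil(n / block_size)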
@@ -514,7 +513,9 @@ class SequenceGroup:
     ) -> None:
         self.request_id = request_id
         self.seqs = seqs
+        self.is_single_seq = len(seqs) == 1
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
+
         self.sampling_params = sampling_params
         self.metrics = RequestMetrics(arrival_time=arrival_time,
                                       last_token_time=arrival_time,
@@ -635,6 +636,10 @@ class SequenceGroup:
     ) -> List[Sequence]:
         if status is None:
             return self.seqs

+        if self.is_single_seq:
+            return self.seqs if self.seqs[0].status == status else []
+
         return [seq for seq in self.seqs if seq.status == status]

     def is_encoder_decoder(self) -> bool:
@@ -644,6 +649,9 @@ class SequenceGroup:
         return self.encoder_seq

     def get_unfinished_seqs(self) -> List[Sequence]:
+        if self.is_single_seq:
+            return self.seqs if not self.seqs[0].is_finished() else []
+
         return [seq for seq in self.seqs if not seq.is_finished()]

     def get_finished_seqs(self) -> List[Sequence]:
@@ -668,12 +676,21 @@ class SequenceGroup:
         if status is None:
             return len(self.seqs)

+        if self.is_single_seq:
+            return 1 if self.seqs[0].status == status else 0
+
         return len(self.get_seqs(status))

     def num_unfinished_seqs(self) -> int:
+        if self.is_single_seq:
+            return 1 if not self.seqs[0].is_finished() else 0
+
         return len(self.get_unfinished_seqs())

     def num_finished_seqs(self) -> int:
+        if self.is_single_seq:
+            return 1 if self.seqs[0].is_finished() else 0
+
         return len(self.get_finished_seqs())

     def find(self, seq_id: int) -> Sequence:
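Most requests contain a single sequence, so these helpers special-case len(seqs) == 1 and skip building a filtered list just to measure its length. A minimal illustration of the shape of that shortcut (names here are illustrative):

from typing import List, Set

def num_unfinished(seqs: List[str], finished: Set[str], is_single_seq: bool) -> int:
    if is_single_seq:
        # No temporary list for the overwhelmingly common 1-sequence case.
        return 0 if seqs[0] in finished else 1
    return len([s for s in seqs if s not in finished])

assert num_unfinished(["seq-a"], set(), True) == 1
assert num_unfinished(["seq-a", "seq-b"], {"seq-a"}, False) == 1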
@@ -686,12 +703,14 @@ class SequenceGroup:
             raise ValueError(f"Sequence {seq.seq_id} already exists.")
         self.seqs_dict[seq.seq_id] = seq
         self.seqs.append(seq)
+        self.is_single_seq = len(self.seqs) == 1

     def remove(self, seq_id: int) -> None:
         seq = self.seqs_dict.pop(seq_id, None)
         if seq is None:
             raise ValueError(f"Sequence {seq_id} not found.")
         self.seqs.remove(seq)
+        self.is_single_seq = len(self.seqs) == 1

     def is_finished(self) -> bool:
         return all(seq.is_finished() for seq in self.seqs)
@@ -775,9 +794,10 @@ class SequenceGroupMetadata:
         # TODO: We should maintain this states out of the sequence group.
         self.num_speculative_tokens = None

-        if self._token_chunk_size is None:
+        if seq_data is not None and self._token_chunk_size is None:
             if is_prompt:
-                self._token_chunk_size = list(seq_data.values())[0].get_len()
+                self._token_chunk_size = next(iter(
+                    seq_data.values())).get_len()
             else:
                 self._token_chunk_size = 1
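`next(iter(seq_data.values()))` grabs the first value without copying every value into a temporary list the way `list(seq_data.values())[0]` does. For example:

seq_data = {7: "first", 8: "second"}
assert next(iter(seq_data.values())) == list(seq_data.values())[0] == "first"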
@@ -261,6 +261,44 @@ class LRUCache(Generic[T]):
         self.cache.clear()


+class PyObjectCache:
+    """Used to cache python objects to avoid object allocations
+    across scheduler iterations.
+    """
+
+    def __init__(self, obj_builder):
+        self._obj_builder = obj_builder
+        self._index = 0
+
+        self._obj_cache = []
+        for _ in range(128):
+            self._obj_cache.append(self._obj_builder())
+
+    def _grow_cache(self):
+        # Double the size of the cache
+        num_objs = len(self._obj_cache)
+        for _ in range(num_objs):
+            self._obj_cache.append(self._obj_builder())
+
+    def get_object(self):
+        """Returns a pre-allocated cached object. If there is not enough
+        objects, then the cache size will double.
+        """
+        if self._index >= len(self._obj_cache):
+            self._grow_cache()
+            assert self._index < len(self._obj_cache)
+
+        obj = self._obj_cache[self._index]
+        self._index += 1
+
+        return obj
+
+    def reset(self):
+        """Makes all cached-objects available for the next scheduler iteration.
+        """
+        self._index = 0
+
+
 def is_hip() -> bool:
     return torch.version.hip is not None
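A short usage sketch of the PyObjectCache shown above; the builder and the dict fields are illustrative, but the get_object()/reset() round-trip is exactly how the scheduler and model runner use it in this commit:

def metadata_builder():
    return {"seq_data": {}, "block_tables": {}}

cache = PyObjectCache(metadata_builder)

# One scheduler iteration: reuse pre-allocated objects instead of allocating.
md = cache.get_object()
md["seq_data"].clear()        # the caller is responsible for clearing stale state
md["block_tables"].clear()
md["seq_data"][0] = "..."

cache.reset()                 # every handed-out object becomes reusable next iteration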
@@ -1,5 +1,6 @@
 import dataclasses
 import gc
+import itertools
 import time
 import warnings
 import weakref
@@ -35,7 +36,7 @@ from vllm.logger import init_logger
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
-from vllm.model_executor import SamplingMetadata
+from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.model_executor.models.interfaces import (supports_lora,
@@ -50,8 +51,8 @@ from vllm.prompt_adapter.worker_manager import (
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
-from vllm.utils import (CudaMemoryProfiler, async_tensor_h2d, flatten_2d_lists,
-                        get_kv_cache_torch_dtype, is_hip,
+from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d,
+                        flatten_2d_lists, get_kv_cache_torch_dtype, is_hip,
                         is_pin_memory_available)
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
@@ -178,6 +179,20 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
     class InterDataForSeqGroup:
         """Intermediate data for the current sequence group."""

+        def simple_reinit(self):
+            self.input_tokens[0].clear()  # type: ignore
+            self.input_positions[0].clear()  # type: ignore
+            self.seq_lens[0] = 0  # type: ignore
+            self.orig_seq_lens[0] = 0  # type: ignore
+            self.query_lens[0] = 0  # type: ignore
+            self.context_lens[0] = 0  # type: ignore
+            self.curr_sliding_window_blocks[0] = 0  # type: ignore
+            self.lora_index_mapping.clear()  # type: ignore
+            self.lora_prompt_mapping.clear()  # type: ignore
+            self.lora_requests.clear()  # type: ignore
+            self.prompt_adapter_index_mapping.clear()  # type: ignore
+            self.prompt_adapter_prompt_mapping.clear()  # type: ignore
+
         def __init__(
             self,
             *,
@@ -220,35 +235,121 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):

             # Whether the prefix cache is hit (prefill only).
             prefix_cache_hit: bool = False,
+            reinit: bool = False,
+            reinit_use_defaults: bool = False,
         ):
+            if reinit:
+                assert len(self.seq_ids) == len(seq_ids)  # type: ignore
+                for i, seq_id in enumerate(seq_ids):
+                    self.seq_ids[i] = seq_id  # type: ignore
+            else:
+                self.seq_ids = seq_ids
+
             self.request_id = request_id
-            self.seq_ids = seq_ids
             self.is_prompt = is_prompt
             self.block_tables = block_tables
             self.computed_block_nums = computed_block_nums
             self.n_seqs = n_seqs
-            self.input_tokens = input_tokens or []
-            self.input_positions = input_positions or []
-            self.seq_lens = seq_lens or []
-            self.orig_seq_lens = orig_seq_lens or []
-            self.query_lens = query_lens or []
-            self.context_lens = context_lens or []
-            self.curr_sliding_window_blocks = curr_sliding_window_blocks or []

-            self.lora_index_mapping = lora_index_mapping or []
-            self.lora_prompt_mapping = lora_prompt_mapping or []
-            self.lora_requests = lora_requests or set()
+            if reinit:
+                if len(self.seq_ids) == 1 and reinit_use_defaults:
+                    self.simple_reinit()
+                else:
+                    if input_tokens:
+                        self.input_tokens = input_tokens
+                    else:
+                        for seq_id in range(len(self.seq_ids)):
+                            self.input_tokens[seq_id].clear()
+
+                    if input_positions:
+                        self.input_positions = input_positions
+                    else:
+                        for seq_id in range(len(self.seq_ids)):
+                            self.input_positions[seq_id].clear()
+
+                    if seq_lens:
+                        self.seq_lens = seq_lens
+                    else:
+                        for seq_id in range(len(self.seq_ids)):
+                            self.seq_lens[seq_id] = 0
+
+                    if orig_seq_lens:
+                        self.orig_seq_lens = orig_seq_lens
+                    else:
+                        for seq_id in range(len(self.seq_ids)):
+                            self.orig_seq_lens[seq_id] = 0
+
+                    if query_lens:
+                        self.query_lens = query_lens
+                    else:
+                        for seq_id in range(len(self.seq_ids)):
+                            self.query_lens[seq_id] = 0
+
+                    if context_lens:
+                        self.context_lens = context_lens
+                    else:
+                        for seq_id in range(len(self.seq_ids)):
+                            self.context_lens[seq_id] = 0
+
+                    if curr_sliding_window_blocks:
+                        self.curr_sliding_window_blocks = \
+                            curr_sliding_window_blocks
+                    else:
+                        for seq_id in range(len(self.seq_ids)):
+                            self.curr_sliding_window_blocks[seq_id] = 0
+
+                    if lora_index_mapping:
+                        self.lora_index_mapping = lora_index_mapping
+                    else:
+                        self.lora_index_mapping.clear()
+
+                    if lora_prompt_mapping:
+                        self.lora_prompt_mapping = lora_prompt_mapping
+                    else:
+                        self.lora_prompt_mapping.clear()
+
+                    if lora_requests:
+                        self.lora_requests = lora_requests
+                    else:
+                        self.lora_requests.clear()
+
+                    if prompt_adapter_index_mapping:
+                        self.prompt_adapter_index_mapping = \
+                            prompt_adapter_index_mapping
+                    else:
+                        self.prompt_adapter_index_mapping.clear()
+
+                    if prompt_adapter_prompt_mapping:
+                        self.prompt_adapter_prompt_mapping = \
+                            prompt_adapter_prompt_mapping
+                    else:
+                        self.prompt_adapter_prompt_mapping.clear()
+
+            else:
+                self.input_tokens = input_tokens or []
+                self.input_positions = input_positions or []
+                self.seq_lens = seq_lens or []
+                self.orig_seq_lens = orig_seq_lens or []
+                self.query_lens = query_lens or []
+                self.context_lens = context_lens or []
+                self.curr_sliding_window_blocks = \
+                    curr_sliding_window_blocks or []
+
+                self.lora_index_mapping = lora_index_mapping or []
+                self.lora_prompt_mapping = lora_prompt_mapping or []
+                self.lora_requests = lora_requests or set()
+
+                self.prompt_adapter_index_mapping = (
+                    prompt_adapter_index_mapping or [])
+                self.prompt_adapter_prompt_mapping = (
+                    prompt_adapter_prompt_mapping or [])

-            self.prompt_adapter_index_mapping = (prompt_adapter_index_mapping
-                                                 or [])
-            self.prompt_adapter_prompt_mapping = (prompt_adapter_prompt_mapping
-                                                  or [])
             self.prompt_adapter_request = prompt_adapter_request

             self.multi_modal_inputs = multi_modal_inputs
             self.prefix_cache_hit = prefix_cache_hit

-            self.__post_init__()
+            if not reinit:
+                self.__post_init__()

         def __post_init__(self):
             self.n_seqs = len(self.seq_ids)
@@ -261,8 +362,36 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             self.context_lens = [0] * self.n_seqs
             self.curr_sliding_window_blocks = [0] * self.n_seqs

-            self.lora_index_mapping = [[] for _ in range(self.n_seqs)]
-            self.lora_prompt_mapping = [[] for _ in range(self.n_seqs)]
+            self.lora_index_mapping = []
+            self.lora_prompt_mapping = []
+
+    def gen_inter_data_builder(self, num_seqs: int):
+        return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup(
+            request_id="",
+            seq_ids=[0] * num_seqs,
+            is_prompt=True,
+            block_tables=None,
+            computed_block_nums=[])
+
+    def init_cached_inter_data(self, *args, **kwargs):
+        assert len(args) == 0
+        assert "seq_ids" in kwargs
+        seq_ids = kwargs["seq_ids"]
+        num_seqs = len(seq_ids)
+
+        # The inter-data cache is per model_runner
+        inter_data_cache = self.runner.inter_data_cache
+        if num_seqs not in inter_data_cache:
+            inter_data_cache[num_seqs] = PyObjectCache(
+                self.gen_inter_data_builder(num_seqs))
+
+        obj = inter_data_cache[num_seqs].get_object()
+        obj.__init__(*args, **kwargs)
+        return obj
+
+    def reset_cached_inter_data(self):
+        for cache in self.runner.inter_data_cache.values():
+            cache.reset()

     def __init__(self,
                  runner: "GPUModelRunnerBase",
@@ -337,17 +466,29 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):

         # Compute tokens.
         if inter_data.is_prompt:
-            tokens = seq_data.get_token_ids()[context_len:seq_len]
+            tokens = seq_data.get_token_ids()
+            if context_len != 0 or seq_len < len(tokens):
+                tokens = tokens[context_len:seq_len]
         else:
             # Optimization. get_token_ids requires the entire copy of
             # tokens.
-            tokens = [seq_data.get_last_token_id()]
+            tokens = seq_data.get_last_token_id()

         inter_data.seq_lens[seq_idx] = seq_len
         inter_data.orig_seq_lens[seq_idx] = seq_len
         inter_data.context_lens[seq_idx] = context_len
-        inter_data.input_tokens[seq_idx] = tokens
-        inter_data.input_positions[seq_idx] = list(range(context_len, seq_len))
+
+        if isinstance(tokens, list):
+            inter_data.input_tokens[seq_idx].extend(tokens)
+        else:
+            inter_data.input_tokens[seq_idx].append(tokens)
+
+        if (seq_len - context_len) == 1:
+            inter_data.input_positions[seq_idx].append(seq_len - 1)
+        else:
+            inter_data.input_positions[seq_idx].extend(
+                range(context_len, seq_len))
+
         inter_data.query_lens[
             seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
@@ -471,7 +612,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):

     def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
         """Add a sequence group to the builder."""
-        seq_ids = list(seq_group_metadata.seq_data.keys())
+        seq_ids = seq_group_metadata.seq_data.keys()
         n_seqs = len(seq_ids)
         is_prompt = seq_group_metadata.is_prompt
@@ -479,12 +620,15 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             assert n_seqs == 1
             self.decode_only = False

-        inter_data = self.InterDataForSeqGroup(
+        inter_data = self.init_cached_inter_data(
             request_id=seq_group_metadata.request_id,
             seq_ids=seq_ids,
             is_prompt=is_prompt,
             block_tables=seq_group_metadata.block_tables,
-            computed_block_nums=seq_group_metadata.computed_block_nums)
+            computed_block_nums=seq_group_metadata.computed_block_nums,
+            reinit=True,
+            reinit_use_defaults=True)
+
         self.inter_data_list.append(inter_data)

         for seq_idx in range(n_seqs):
@@ -504,18 +648,21 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         create on-device tensors.
         """
         # Combine and flatten intermediate data.
-        input_tokens = flatten_2d_lists([
-            flatten_2d_lists(inter_data.input_tokens)
-            for inter_data in self.inter_data_list
-        ])
+        input_tokens = []
+        for inter_data in self.inter_data_list:
+            for cur_input_tokens in inter_data.input_tokens:
+                input_tokens.extend(cur_input_tokens)
+
         if not input_tokens:
             # This may happen when all prefill requests hit
             # prefix caching and there is no decode request.
             return self.model_input_cls()
-        input_positions = flatten_2d_lists([
-            flatten_2d_lists(inter_data.input_positions)
-            for inter_data in self.inter_data_list
-        ])
+
+        input_positions = []
+        for inter_data in self.inter_data_list:
+            for cur_input_positions in inter_data.input_positions:
+                input_positions.extend(cur_input_positions)
+
         seq_lens = []
         max_decode_seq_len = 0
         for inter_data in self.inter_data_list:
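Extending one output list inside the loop avoids the intermediate per-group lists that the flatten-then-flatten approach builds before the final concatenation. A small equivalence check (the flatten_2d_lists helper below is a local stand-in mirroring the utility's behavior):

def flatten_2d_lists(lists):
    return [item for sublist in lists for item in sublist]

inter_data_tokens = [[[1, 2], [3]], [[4, 5]]]   # per group, per sequence

# Nested flatten: allocates a flattened list per group before the final list.
flattened = flatten_2d_lists(
    [flatten_2d_lists(tokens) for tokens in inter_data_tokens])

# Single pass: one result list, extended in place.
tokens = []
for group in inter_data_tokens:
    for seq_tokens in group:
        tokens.extend(seq_tokens)

assert tokens == flattened == [1, 2, 3, 4, 5]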
@@ -523,8 +670,10 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             if not inter_data.is_prompt:
                 max_decode_seq_len = max(max_decode_seq_len,
                                          max(inter_data.seq_lens))
-        query_lens = flatten_2d_lists(
-            [inter_data.query_lens for inter_data in self.inter_data_list])
+
+        query_lens = []
+        for inter_data in self.inter_data_list:
+            query_lens.extend(inter_data.query_lens)

         # Mapping from request IDs to sequence IDs. Used for Jamba models
         # that manages the cache by itself.
         request_ids_to_seq_ids = {
@@ -547,8 +696,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             batch_size = graph_batch_size

         # Tokens and positions.
-        input_tokens.extend([0] * cuda_graph_pad_size)
-        input_positions.extend([0] * cuda_graph_pad_size)
+        if cuda_graph_pad_size:
+            input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
+            input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
         assert self.runner.device is not None
         input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
                                                self.runner.device,
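`list.extend(itertools.repeat(0, n))` feeds the pad values straight into the target list, whereas `[0] * n` first materializes a throwaway list of size n; guarding with `if cuda_graph_pad_size:` also skips the call entirely in the common unpadded case. For example:

import itertools

pad = 3
a = [1, 2]
a.extend(itertools.repeat(0, pad))   # no temporary [0, 0, 0] list is built
assert a == [1, 2, 0, 0, 0]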
@@ -558,7 +708,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                                                self.runner.pin_memory)

         # Sequence and query lengths.
-        seq_lens.extend([1] * cuda_graph_pad_size)
+        if cuda_graph_pad_size:
+            seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))

         # Attention metadata.
         attn_metadata = self.attn_metadata_builder.build(
@@ -574,11 +725,14 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             flatten_2d_lists(inter_data.lora_index_mapping)
             for inter_data in self.inter_data_list
         ])
-        lora_index_mapping.extend([0] * cuda_graph_pad_size)
+        if cuda_graph_pad_size:
+            lora_index_mapping.extend(
+                itertools.repeat(0, cuda_graph_pad_size))
         lora_prompt_mapping = flatten_2d_lists([
             flatten_2d_lists(inter_data.lora_prompt_mapping)
             for inter_data in self.inter_data_list
         ])

         lora_mapping = LoRAMapping(
             **dict(index_mapping=lora_index_mapping,
                    prompt_mapping=lora_prompt_mapping,
@@ -595,7 +749,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             inter_data.prompt_adapter_index_mapping
             for inter_data in self.inter_data_list
         ])
-        prompt_adapter_index_mapping.extend([0] * cuda_graph_pad_size)
+        if cuda_graph_pad_size:
+            prompt_adapter_index_mapping.extend(
+                itertools.repeat(0, cuda_graph_pad_size))
         prompt_adapter_prompt_mapping = flatten_2d_lists([
             inter_data.prompt_adapter_prompt_mapping
             for inter_data in self.inter_data_list
@@ -717,6 +873,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         set_cpu_offload_max_bytes(
             int(self.cache_config.cpu_offload_gb * 1024**3))

+        # Used to cache python objects
+        self.inter_data_cache: Dict[int, PyObjectCache] = {}
+        self.sampling_metadata_cache: SamplingMetadataCache = \
+            SamplingMetadataCache()
+
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with CudaMemoryProfiler() as m:
@@ -843,6 +1004,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
         for seq_group_metadata in seq_group_metadata_list:
             builder.add_seq_group(seq_group_metadata)
+
+        builder.reset_cached_inter_data()
+
         return builder.build()  # type: ignore

     @torch.inference_mode()
@@ -1276,7 +1440,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
             sampling_metadata = SamplingMetadata.prepare(
                 seq_group_metadata_list, model_input.seq_lens,
                 model_input.query_lens, self.device, self.pin_memory,
-                generators)
+                generators, self.sampling_metadata_cache)
         else:
             sampling_metadata = None
             is_prompt = (seq_group_metadata_list[0].is_prompt