mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 12:44:27 +08:00
Fix can_swap_in
This commit is contained in:
parent
a2a9869cb7
commit
eb52db1bea
@ -70,13 +70,14 @@ class BlockSpaceManager:
|
||||
self.block_tables: Dict[int, BlockTable] = {}
|
||||
|
||||
def can_allocate(self, seq_group: SequenceGroup) -> bool:
|
||||
# NOTE: Here we assume that all sequences in the group have the same prompt.
|
||||
seq = seq_group.seqs[0]
|
||||
num_required_blocks = len(seq.logical_token_blocks)
|
||||
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
|
||||
return num_required_blocks <= num_free_gpu_blocks
|
||||
|
||||
def allocate(self, seq_group: SequenceGroup) -> None:
|
||||
# Here, we assume that all sequences in the group have the same prompt.
|
||||
# NOTE: Here we assume that all sequences in the group have the same prompt.
|
||||
seq = seq_group.seqs[0]
|
||||
|
||||
# Allocate new physical token blocks that will store the prompt tokens.
|
||||
@ -124,10 +125,10 @@ class BlockSpaceManager:
|
||||
self.gpu_allocator.free(last_block)
|
||||
return last_block.block_number, new_block.block_number
|
||||
|
||||
def fork(self, src_seq: Sequence, child_seq: Sequence) -> None:
|
||||
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
|
||||
# NOTE: fork does not allocate a new physical block.
|
||||
# Thus, it is always safe from OOM.
|
||||
src_block_table = self.block_tables[src_seq.seq_id]
|
||||
src_block_table = self.block_tables[parent_seq.seq_id]
|
||||
self.block_tables[child_seq.seq_id] = src_block_table.copy()
|
||||
for block in src_block_table:
|
||||
block.ref_count += 1
|
||||
@ -146,7 +147,12 @@ class BlockSpaceManager:
|
||||
|
||||
def can_swap_in(self, seq_group: SequenceGroup) -> bool:
|
||||
blocks = self._get_physical_blocks(seq_group)
|
||||
return len(blocks) <= self.gpu_allocator.get_num_free_blocks()
|
||||
num_running_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
|
||||
num_free_blocks = self.gpu_allocator.get_num_free_blocks()
|
||||
# NOTE: Conservatively, we assume that every sequence will allocate
|
||||
# at least one free block right after the swap-in.
|
||||
# NOTE: This should match the logic in can_append().
|
||||
return len(blocks) + num_running_seqs <= num_free_blocks
|
||||
|
||||
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
|
||||
# src_block_number -> dst_block_number
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user