[core][misc] improve free_finished_seq_groups (#6865)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
youkaichao 2024-07-30 14:32:12 -07:00 committed by GitHub
parent d7a299edaa
commit 6ca8031e71
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -313,6 +313,7 @@ class Scheduler:
# IDs of the requests (sequence groups) that finished since the last step iteration. # IDs of the requests (sequence groups) that finished since the last step iteration.
# It lets the model know that any state associated with these requests # It lets the model know that any state associated with these requests
# can and must be released after the current step. # can and must be released after the current step.
# This is used to evict the finished requests from the Mamba cache.
self._finished_requests_ids: List[str] = list() self._finished_requests_ids: List[str] = list()
# Time at previous scheduling step # Time at previous scheduling step
self.prev_time = 0.0 self.prev_time = 0.0
@ -374,6 +375,7 @@ class Scheduler:
for aborted_group in aborted_groups: for aborted_group in aborted_groups:
# Remove the sequence group from the state queue. # Remove the sequence group from the state queue.
state_queue.remove(aborted_group) state_queue.remove(aborted_group)
# Remove the aborted request from the Mamba cache.
self._finished_requests_ids.append(aborted_group.request_id) self._finished_requests_ids.append(aborted_group.request_id)
for seq in aborted_group.get_seqs(): for seq in aborted_group.get_seqs():
if seq.is_finished(): if seq.is_finished():
@ -1057,13 +1059,16 @@ class Scheduler:
self.block_manager.free(seq) self.block_manager.free(seq)
def free_finished_seq_groups(self) -> None: def free_finished_seq_groups(self) -> None:
for queue in [self.running, self.swapped, self.waiting]: remaining: Deque[SequenceGroup] = deque()
self._finished_requests_ids += [ for seq_group in self.running:
seq_group.request_id for seq_group in queue if seq_group.is_finished():
if seq_group.is_finished() # Add the finished requests to the finished requests list.
] # This list will be used to update the Mamba cache in the
self.running = deque(seq_group for seq_group in self.running # next step.
if not seq_group.is_finished()) self._finished_requests_ids.append(seq_group.request_id)
else:
remaining.append(seq_group)
self.running = remaining
def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group) self.block_manager.allocate(seq_group)