[Misc] Minor refactoring for scheduler (#20299)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-07-01 07:55:32 -07:00 committed by GitHub
parent ecad851cbd
commit 0e96cc9b7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -580,6 +580,13 @@ class Scheduler(SchedulerInterface):
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
self._update_after_schedule(scheduler_output)
return scheduler_output
def _update_after_schedule(
self,
scheduler_output: SchedulerOutput,
) -> None:
# Advance the number of computed tokens for the request AFTER
# the request is scheduled.
# 1. The scheduler_output of the current step has to include the
@ -589,11 +596,15 @@ class Scheduler(SchedulerInterface):
# scheduling step.
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
# computed tokens will be adjusted in update_from_output.
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
for req_id, num_scheduled_token in num_scheduled_tokens.items():
self.requests[req_id].num_computed_tokens += num_scheduled_token
request = self.requests[req_id]
request.num_computed_tokens += num_scheduled_token
# Clear the finished request IDs.
# NOTE: We shouldn't do self.finished_req_ids.clear() here because
# it will also affect the scheduler output.
self.finished_req_ids = set()
return scheduler_output
def _make_cached_request_data(
self,
@ -763,19 +774,10 @@ class Scheduler(SchedulerInterface):
num_draft_tokens=len(scheduled_spec_token_ids),
num_accepted_tokens=len(generated_token_ids) - 1)
cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request))
# OPTIMIZATION: Avoid list(set) if the set is empty.
if cached_encoder_input_ids:
for input_id in list(cached_encoder_input_ids):
mm_positions = request.mm_positions[input_id]
start_pos = mm_positions.offset
num_tokens = mm_positions.length
if start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
self.encoder_cache_manager.free_encoder_input(
request, input_id)
# NOTE(woosuk): This has to be executed after updating
# `request.num_computed_tokens`.
if request.has_encoder_inputs:
self._free_encoder_inputs(request)
stopped = False
new_logprobs = None
@ -891,6 +893,25 @@ class Scheduler(SchedulerInterface):
return engine_core_outputs
def _free_encoder_inputs(self, request: Request) -> None:
cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request))
# OPTIMIZATION: Avoid list(set) if the set is empty.
if not cached_encoder_input_ids:
return
# Here, we use list(set) to avoid modifying the set while iterating
# over it.
for input_id in list(cached_encoder_input_ids):
mm_positions = request.mm_positions[input_id]
start_pos = mm_positions.offset
num_tokens = mm_positions.length
if start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
self.encoder_cache_manager.free_encoder_input(
request, input_id)
def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
return len(self.running), len(self.waiting)