Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-23 15:24:28 +08:00)
[Misc] Minor refactoring for scheduler (#20299)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent ecad851cbd
commit 0e96cc9b7e
@@ -580,6 +580,13 @@ class Scheduler(SchedulerInterface):
             batch = KVEventBatch(ts=time.time(), events=events)
             self.kv_event_publisher.publish(batch)
 
+        self._update_after_schedule(scheduler_output)
+        return scheduler_output
+
+    def _update_after_schedule(
+        self,
+        scheduler_output: SchedulerOutput,
+    ) -> None:
         # Advance the number of computed tokens for the request AFTER
         # the request is scheduled.
         # 1. The scheduler_output of the current step has to include the
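For orientation, the shape of this first change reduced to a standalone skeleton; only the method names come from the diff, the bodies here are elided placeholders:

class Scheduler:
    def schedule(self) -> "SchedulerOutput":
        scheduler_output = ...  # assemble this step's output (elided)
        # Post-scheduling bookkeeping now lives in its own helper.
        self._update_after_schedule(scheduler_output)
        return scheduler_output

    def _update_after_schedule(self, scheduler_output: "SchedulerOutput") -> None:
        ...  # advance num_computed_tokens, reset finished_req_ids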
@@ -589,11 +596,15 @@ class Scheduler(SchedulerInterface):
         # scheduling step.
         # 3. If some tokens (e.g. spec tokens) are rejected later, the number of
         # computed tokens will be adjusted in update_from_output.
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
         for req_id, num_scheduled_token in num_scheduled_tokens.items():
-            self.requests[req_id].num_computed_tokens += num_scheduled_token
+            request = self.requests[req_id]
+            request.num_computed_tokens += num_scheduled_token
 
         # Clear the finished request IDs.
         # NOTE: We shouldn't do self.finished_req_ids.clear() here because
         # it will also affect the scheduler output.
         self.finished_req_ids = set()
-        return scheduler_output
 
     def _make_cached_request_data(
         self,
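The comment above notes that rejected speculative tokens are reconciled later in update_from_output. A minimal, self-contained sketch of that optimistic-advance-then-adjust pattern; the Request stub and the rollback arithmetic are illustrative assumptions, not vLLM's exact code:

from dataclasses import dataclass

@dataclass
class Request:  # illustrative stand-in for vLLM's Request
    num_computed_tokens: int = 0

req = Request()

# schedule(): advance optimistically by everything scheduled this step,
# e.g. 1 verified token plus 3 speculative draft tokens.
num_scheduled_token = 4
req.num_computed_tokens += num_scheduled_token  # -> 4

# update_from_output(): suppose only 1 of the 3 drafts was accepted,
# so the 2 rejected tokens are subtracted back out.
num_draft_tokens, num_accepted_tokens = 3, 1
req.num_computed_tokens -= num_draft_tokens - num_accepted_tokens  # -> 2

assert req.num_computed_tokens == 2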
@@ -763,19 +774,10 @@ class Scheduler(SchedulerInterface):
                     num_draft_tokens=len(scheduled_spec_token_ids),
                     num_accepted_tokens=len(generated_token_ids) - 1)
 
-            cached_encoder_input_ids = (
-                self.encoder_cache_manager.get_cached_input_ids(request))
-            # OPTIMIZATION: Avoid list(set) if the set is empty.
-            if cached_encoder_input_ids:
-                for input_id in list(cached_encoder_input_ids):
-                    mm_positions = request.mm_positions[input_id]
-                    start_pos = mm_positions.offset
-                    num_tokens = mm_positions.length
-                    if start_pos + num_tokens <= request.num_computed_tokens:
-                        # The encoder output is already processed and stored
-                        # in the decoder's KV cache.
-                        self.encoder_cache_manager.free_encoder_input(
-                            request, input_id)
+            # NOTE(woosuk): This has to be executed after updating
+            # `request.num_computed_tokens`.
+            if request.has_encoder_inputs:
+                self._free_encoder_inputs(request)
 
             stopped = False
             new_logprobs = None
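The new call site only reaches the helper for requests that actually carry encoder inputs, and the NOTE explains the ordering: the freeing test compares against the freshly advanced num_computed_tokens. A toy evaluation of that condition, with made-up numbers for illustration:

# An encoder input covering positions [start_pos, start_pos + num_tokens)
# may be freed once the decoder has computed past its last position.
start_pos, num_tokens = 10, 5  # covers positions [10, 15)

for num_computed_tokens in (12, 15, 20):
    can_free = start_pos + num_tokens <= num_computed_tokens
    print(num_computed_tokens, can_free)  # 12 False, 15 True, 20 True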
@@ -891,6 +893,25 @@ class Scheduler(SchedulerInterface):
 
         return engine_core_outputs
 
+    def _free_encoder_inputs(self, request: Request) -> None:
+        cached_encoder_input_ids = (
+            self.encoder_cache_manager.get_cached_input_ids(request))
+        # OPTIMIZATION: Avoid list(set) if the set is empty.
+        if not cached_encoder_input_ids:
+            return
+
+        # Here, we use list(set) to avoid modifying the set while iterating
+        # over it.
+        for input_id in list(cached_encoder_input_ids):
+            mm_positions = request.mm_positions[input_id]
+            start_pos = mm_positions.offset
+            num_tokens = mm_positions.length
+            if start_pos + num_tokens <= request.num_computed_tokens:
+                # The encoder output is already processed and stored
+                # in the decoder's KV cache.
+                self.encoder_cache_manager.free_encoder_input(
+                    request, input_id)
+
     def get_request_counts(self) -> tuple[int, int]:
         """Returns (num_running_reqs, num_waiting_reqs)."""
         return len(self.running), len(self.waiting)
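The list(set) comments in the new helper refer to a standard CPython pitfall: removing elements from a set while iterating over it raises RuntimeError. A quick demonstration, with discard() standing in for the removal that free_encoder_input performs:

s = {1, 2, 3}
try:
    for x in s:
        s.discard(x)  # mutates the set mid-iteration
except RuntimeError as e:
    print(e)  # Set changed size during iteration

# Iterating over a snapshot makes removal safe, at the cost of one copy;
# hence the "avoid list(set) if the set is empty" fast path above.
s = {1, 2, 3}
for x in list(s):
    s.discard(x)
assert not s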