mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 13:37:56 +08:00
[Misc] Minor refactoring for scheduler (#20299)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
ecad851cbd
commit
0e96cc9b7e
@ -580,6 +580,13 @@ class Scheduler(SchedulerInterface):
|
|||||||
batch = KVEventBatch(ts=time.time(), events=events)
|
batch = KVEventBatch(ts=time.time(), events=events)
|
||||||
self.kv_event_publisher.publish(batch)
|
self.kv_event_publisher.publish(batch)
|
||||||
|
|
||||||
|
self._update_after_schedule(scheduler_output)
|
||||||
|
return scheduler_output
|
||||||
|
|
||||||
|
def _update_after_schedule(
|
||||||
|
self,
|
||||||
|
scheduler_output: SchedulerOutput,
|
||||||
|
) -> None:
|
||||||
# Advance the number of computed tokens for the request AFTER
|
# Advance the number of computed tokens for the request AFTER
|
||||||
# the request is scheduled.
|
# the request is scheduled.
|
||||||
# 1. The scheduler_output of the current step has to include the
|
# 1. The scheduler_output of the current step has to include the
|
||||||
@ -589,11 +596,15 @@ class Scheduler(SchedulerInterface):
|
|||||||
# scheduling step.
|
# scheduling step.
|
||||||
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
|
# 3. If some tokens (e.g. spec tokens) are rejected later, the number of
|
||||||
# computed tokens will be adjusted in update_from_output.
|
# computed tokens will be adjusted in update_from_output.
|
||||||
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||||
for req_id, num_scheduled_token in num_scheduled_tokens.items():
|
for req_id, num_scheduled_token in num_scheduled_tokens.items():
|
||||||
self.requests[req_id].num_computed_tokens += num_scheduled_token
|
request = self.requests[req_id]
|
||||||
|
request.num_computed_tokens += num_scheduled_token
|
||||||
|
|
||||||
|
# Clear the finished request IDs.
|
||||||
|
# NOTE: We shouldn't do self.finished_req_ids.clear() here because
|
||||||
|
# it will also affect the scheduler output.
|
||||||
self.finished_req_ids = set()
|
self.finished_req_ids = set()
|
||||||
return scheduler_output
|
|
||||||
|
|
||||||
def _make_cached_request_data(
|
def _make_cached_request_data(
|
||||||
self,
|
self,
|
||||||
@ -763,19 +774,10 @@ class Scheduler(SchedulerInterface):
|
|||||||
num_draft_tokens=len(scheduled_spec_token_ids),
|
num_draft_tokens=len(scheduled_spec_token_ids),
|
||||||
num_accepted_tokens=len(generated_token_ids) - 1)
|
num_accepted_tokens=len(generated_token_ids) - 1)
|
||||||
|
|
||||||
cached_encoder_input_ids = (
|
# NOTE(woosuk): This has to be executed after updating
|
||||||
self.encoder_cache_manager.get_cached_input_ids(request))
|
# `request.num_computed_tokens`.
|
||||||
# OPTIMIZATION: Avoid list(set) if the set is empty.
|
if request.has_encoder_inputs:
|
||||||
if cached_encoder_input_ids:
|
self._free_encoder_inputs(request)
|
||||||
for input_id in list(cached_encoder_input_ids):
|
|
||||||
mm_positions = request.mm_positions[input_id]
|
|
||||||
start_pos = mm_positions.offset
|
|
||||||
num_tokens = mm_positions.length
|
|
||||||
if start_pos + num_tokens <= request.num_computed_tokens:
|
|
||||||
# The encoder output is already processed and stored
|
|
||||||
# in the decoder's KV cache.
|
|
||||||
self.encoder_cache_manager.free_encoder_input(
|
|
||||||
request, input_id)
|
|
||||||
|
|
||||||
stopped = False
|
stopped = False
|
||||||
new_logprobs = None
|
new_logprobs = None
|
||||||
@ -891,6 +893,25 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
return engine_core_outputs
|
return engine_core_outputs
|
||||||
|
|
||||||
|
def _free_encoder_inputs(self, request: Request) -> None:
|
||||||
|
cached_encoder_input_ids = (
|
||||||
|
self.encoder_cache_manager.get_cached_input_ids(request))
|
||||||
|
# OPTIMIZATION: Avoid list(set) if the set is empty.
|
||||||
|
if not cached_encoder_input_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Here, we use list(set) to avoid modifying the set while iterating
|
||||||
|
# over it.
|
||||||
|
for input_id in list(cached_encoder_input_ids):
|
||||||
|
mm_positions = request.mm_positions[input_id]
|
||||||
|
start_pos = mm_positions.offset
|
||||||
|
num_tokens = mm_positions.length
|
||||||
|
if start_pos + num_tokens <= request.num_computed_tokens:
|
||||||
|
# The encoder output is already processed and stored
|
||||||
|
# in the decoder's KV cache.
|
||||||
|
self.encoder_cache_manager.free_encoder_input(
|
||||||
|
request, input_id)
|
||||||
|
|
||||||
def get_request_counts(self) -> tuple[int, int]:
|
def get_request_counts(self) -> tuple[int, int]:
|
||||||
"""Returns (num_running_reqs, num_waiting_reqs)."""
|
"""Returns (num_running_reqs, num_waiting_reqs)."""
|
||||||
return len(self.running), len(self.waiting)
|
return len(self.running), len(self.waiting)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user