From da0706721595bf892d7cb917cc02229775c967b4 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 12 Mar 2025 22:04:57 -0700
Subject: [PATCH] Add common states

Signed-off-by: Woosuk Kwon
---
 vllm/v1/core/sched/common.py    | 50 +++++++++++++++++++++++++++++
 vllm/v1/core/sched/scheduler.py | 57 ++++++---------------------------
 2 files changed, 60 insertions(+), 47 deletions(-)
 create mode 100644 vllm/v1/core/sched/common.py

diff --git a/vllm/v1/core/sched/common.py b/vllm/v1/core/sched/common.py
new file mode 100644
index 0000000000000..86df8aa362dc1
--- /dev/null
+++ b/vllm/v1/core/sched/common.py
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+from vllm.v1.core.sched.output import CachedRequestData
+from vllm.v1.request import Request
+
+
+class CommonSchedulerStates:
+
+    def __init__(self):
+        # The request IDs that are finished in between the previous and the
+        # current steps. This is used to notify the workers about the finished
+        # requests so that they can free the cached states for those requests.
+        # This is flushed at the end of each scheduling step.
+        self.finished_req_ids: set[str] = set()
+
+        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
+        # them at each scheduling step.
+        # Request id -> CachedRequestData
+        self._cached_reqs_data: dict[str, CachedRequestData] = {}
+
+    def make_cached_request_data(
+        self,
+        request: Request,
+        num_scheduled_tokens: int,
+        num_scheduled_spec_tokens: int,
+        new_block_ids: list[int],
+        resumed_from_preemption: bool,
+    ) -> CachedRequestData:
+        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
+        # them at each scheduling step.
+        num_computed_tokens = request.num_computed_tokens
+        num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
+        new_token_ids = request.all_token_ids[
+            num_computed_tokens:num_computed_tokens + num_regular_tokens]
+        req_data = self._cached_reqs_data.get(request.request_id)
+        if req_data is not None:
+            req_data.resumed_from_preemption = resumed_from_preemption
+            req_data.new_token_ids = new_token_ids
+            req_data.new_block_ids = new_block_ids
+            req_data.num_computed_tokens = num_computed_tokens
+        else:
+            req_data = CachedRequestData.from_request(request,
+                                                      resumed_from_preemption,
+                                                      new_token_ids,
+                                                      new_block_ids)
+            self._cached_reqs_data[request.request_id] = req_data
+        return req_data
+
+    def free_request(self, request: Request) -> None:
+        self._cached_reqs_data.pop(request.request_id, None)
+        self.finished_req_ids.add(request.request_id)
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 30e26f402d639..75e93eaf67f7a 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -13,9 +13,9 @@ from vllm.logger import init_logger
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                 compute_encoder_budget)
 from vllm.v1.core.kv_cache_manager import KVCacheManager
+from vllm.v1.core.sched.common import CommonSchedulerStates
 from vllm.v1.core.sched.interface import SchedulerInterface
-from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
-                                       SchedulerOutput)
+from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
 from vllm.v1.core.sched.utils import check_stop
 from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                             EngineCoreOutput, EngineCoreOutputs)
@@ -73,16 +73,8 @@ class Scheduler(SchedulerInterface):
         # by the executor.
         self.scheduled_req_ids: set[str] = set()
 
-        # The request IDs that are finished in between the previous and the
-        # current steps. This is used to notify the workers about the finished
-        # requests so that they can free the cached states for those requests.
-        # This is flushed at the end of each scheduling step.
-        self.finished_req_ids: set[str] = set()
-
-        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
-        # them at each scheduling step.
-        # Request id -> CachedRequestData
-        self._cached_reqs_data: dict[str, CachedRequestData] = {}
+        # Misc states for the scheduler.
+        self.states = CommonSchedulerStates()
 
         # Encoder-related.
         # Calculate encoder cache size if applicable
@@ -386,7 +378,7 @@ class Scheduler(SchedulerInterface):
             for req in scheduled_new_reqs
         ]
         resumed_reqs_data = [
-            self._make_cached_request_data(
+            self.states.make_cached_request_data(
                 req,
                 num_scheduled_tokens[req.request_id],
                 len(scheduled_spec_decode_tokens.get(req.request_id, ())),
@@ -395,7 +387,7 @@ class Scheduler(SchedulerInterface):
             ) for req in scheduled_resumed_reqs
         ]
         running_reqs_data = [
-            self._make_cached_request_data(
+            self.states.make_cached_request_data(
                 req,
                 num_scheduled_tokens[req.request_id],
                 len(scheduled_spec_decode_tokens.get(req.request_id, ())),
@@ -415,43 +407,15 @@ class Scheduler(SchedulerInterface):
             # instead of being newly scheduled in this step.
             # It contains the request IDs that are finished in between
             # the previous and the current steps.
-            finished_req_ids=self.finished_req_ids,
+            finished_req_ids=self.states.finished_req_ids,
             free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
             structured_output_request_ids=structured_output_request_ids,
             grammar_bitmask=grammar_bitmask,
         )
 
-        self.finished_req_ids = set()
+        self.states.finished_req_ids = set()
         return scheduler_output
 
-    def _make_cached_request_data(
-        self,
-        request: Request,
-        num_scheduled_tokens: int,
-        num_scheduled_spec_tokens: int,
-        new_block_ids: list[int],
-        resumed_from_preemption: bool,
-    ) -> CachedRequestData:
-        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
-        # them at each scheduling step.
-        num_computed_tokens = request.num_computed_tokens
-        num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
-        new_token_ids = request.all_token_ids[
-            num_computed_tokens:num_computed_tokens + num_regular_tokens]
-        req_data = self._cached_reqs_data.get(request.request_id)
-        if req_data is not None:
-            req_data.resumed_from_preemption = resumed_from_preemption
-            req_data.new_token_ids = new_token_ids
-            req_data.new_block_ids = new_block_ids
-            req_data.num_computed_tokens = num_computed_tokens
-        else:
-            req_data = CachedRequestData.from_request(request,
-                                                      resumed_from_preemption,
-                                                      new_token_ids,
-                                                      new_block_ids)
-            self._cached_reqs_data[request.request_id] = req_data
-        return req_data
-
     def _try_schedule_encoder_inputs(
         self,
         request: Request,
@@ -688,15 +652,14 @@ class Scheduler(SchedulerInterface):
         self.kv_cache_manager.free(request)
         self.kv_cache_manager.free_block_hashes(request)
         self.encoder_cache_manager.free(request)
-        self._cached_reqs_data.pop(request.request_id, None)
+        self.states.free_request(request)
         del self.requests[request.request_id]
-        self.finished_req_ids.add(request.request_id)
 
     def get_num_unfinished_requests(self) -> int:
         return len(self.waiting) + len(self.running)
 
     def has_finished_requests(self) -> bool:
-        return len(self.finished_req_ids) > 0
+        return len(self.states.finished_req_ids) > 0
 
     def get_num_unscheduled_requests(self) -> int:
         """Number of requests that are not being processed by the executor."""
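
Usage sketch (illustrative, not part of the patch): the snippet below
exercises the behavior CommonSchedulerStates is built around, using
hypothetical stand-ins for vllm's Request and CachedRequestData so that it
runs standalone. Only the CommonSchedulerStates body mirrors the code added
in common.py; every other name is an assumption made for the sketch.

    from dataclasses import dataclass


    @dataclass
    class Request:  # hypothetical stand-in for vllm.v1.request.Request
        request_id: str
        all_token_ids: list[int]
        num_computed_tokens: int = 0


    @dataclass
    class CachedRequestData:  # stand-in for the vllm.v1.core.sched.output class
        req_id: str
        resumed_from_preemption: bool
        new_token_ids: list[int]
        new_block_ids: list[int]
        num_computed_tokens: int

        @classmethod
        def from_request(cls, request, resumed_from_preemption,
                         new_token_ids, new_block_ids):
            return cls(request.request_id, resumed_from_preemption,
                       new_token_ids, new_block_ids,
                       request.num_computed_tokens)


    class CommonSchedulerStates:  # same logic as the class the patch adds
        def __init__(self):
            self.finished_req_ids: set[str] = set()
            self._cached_reqs_data: dict[str, CachedRequestData] = {}

        def make_cached_request_data(self, request, num_scheduled_tokens,
                                     num_scheduled_spec_tokens, new_block_ids,
                                     resumed_from_preemption):
            num_computed_tokens = request.num_computed_tokens
            # Only the non-speculative tokens are sliced out of all_token_ids.
            num_regular_tokens = (num_scheduled_tokens -
                                  num_scheduled_spec_tokens)
            new_token_ids = request.all_token_ids[
                num_computed_tokens:num_computed_tokens + num_regular_tokens]
            req_data = self._cached_reqs_data.get(request.request_id)
            if req_data is not None:
                # Cache hit: refresh the fields of the existing object
                # instead of allocating a new one every scheduling step.
                req_data.resumed_from_preemption = resumed_from_preemption
                req_data.new_token_ids = new_token_ids
                req_data.new_block_ids = new_block_ids
                req_data.num_computed_tokens = num_computed_tokens
            else:
                req_data = CachedRequestData.from_request(
                    request, resumed_from_preemption, new_token_ids,
                    new_block_ids)
                self._cached_reqs_data[request.request_id] = req_data
            return req_data

        def free_request(self, request):
            self._cached_reqs_data.pop(request.request_id, None)
            self.finished_req_ids.add(request.request_id)


    states = CommonSchedulerStates()
    req = Request("req-0", all_token_ids=[1, 2, 3, 4, 5])

    # Step 1: the first call allocates the CachedRequestData for this request.
    first = states.make_cached_request_data(req, num_scheduled_tokens=3,
                                            num_scheduled_spec_tokens=1,
                                            new_block_ids=[0],
                                            resumed_from_preemption=False)
    assert first.new_token_ids == [1, 2]  # 3 scheduled - 1 speculative

    # Step 2: the same object is reused; only its fields are refreshed.
    req.num_computed_tokens = 2
    second = states.make_cached_request_data(req, 2, 0, [1], False)
    assert second is first and second.new_token_ids == [3, 4]

    # Finishing a request drops its cache entry and records its ID so the
    # workers can be told to free their per-request state.
    states.free_request(req)
    assert states.finished_req_ids == {"req-0"}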
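One detail on the scheduler.py side, sketched with the same stand-ins:
finished_req_ids is passed into the SchedulerOutput by reference, which is
presumably why the patch flushes it by rebinding a fresh set rather than
calling .clear(), leaving the already-handed-out set intact for the consumer.

    # Continuing the sketch above: the per-step flush pattern.
    states = CommonSchedulerStates()
    states.free_request(Request("req-1", all_token_ids=[]))

    # What scheduler.py does at the end of a scheduling step:
    finished_for_workers = states.finished_req_ids  # goes to SchedulerOutput
    states.finished_req_ids = set()                 # flush; not .clear()

    # The set already handed out is untouched; the next step starts empty.
    assert finished_for_workers == {"req-1"}
    assert not states.finished_req_ids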