From 90eb28ca218e14369f022716399d048b8aaf7f51 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 13 Mar 2025 16:11:07 -0400
Subject: [PATCH] [V1][Scheduler] Use dict for running queue

This is just a random idea; it still needs to be benchmarked.

Potential advantages for large batch sizes:
- No need to copy the entire list on every iteration
- O(1) removal of aborted requests

Signed-off-by: Nick Hill
---
For reference, standalone sketches of the dict-as-ordered-set pattern and
of the deferred preemption bookkeeping are appended after the diff.

 vllm/v1/core/scheduler.py | 45 ++++++++++++++++++++-------------------
 1 file changed, 23 insertions(+), 22 deletions(-)

diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index d498891f476e7..c264678d76782 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -66,7 +66,7 @@ class Scheduler:
         self.requests: dict[str, Request] = {}
         # Priority queues for requests.
         self.waiting: deque[Request] = deque()
-        self.running: list[Request] = []
+        self.running: dict[Request, None] = {}
         # The requests that have been scheduled and are being executed
         # by the executor.
         self.scheduled_req_ids: set[str] = set()
@@ -140,12 +140,12 @@ class Scheduler:
         scheduled_timestamp = time.monotonic()
 
         # First, schedule the RUNNING requests.
-        req_index = 0
-        while req_index < len(self.running) and token_budget > 0:
-            request = self.running[req_index]
+        running_count = len(self.running)
+        for req_index, request in enumerate(self.running):
+            if token_budget <= 0 or req_index == running_count:
+                break
             if request.request_id in self.scheduled_req_ids:
                 # This request has already been scheduled.
-                req_index += 1
                 continue
 
             num_new_tokens = (request.num_tokens_with_spec -
@@ -165,7 +165,6 @@
                 # NOTE(woosuk): Here, by doing `continue` instead of `break`,
                 # we do not strictly follow the FCFS scheduling policy and
                 # allow the lower-priority requests to be scheduled.
-                req_index += 1
                 continue
 
             while True:
@@ -174,7 +173,8 @@
                 if new_blocks is None:
                     # The request cannot be scheduled.
                     # Preempt the lowest-priority request.
-                    preempted_req = self.running.pop()
+                    preempted_req = list(self.running)[running_count - 1]
+                    running_count -= 1
                     self.kv_cache_manager.free(preempted_req)
                     preempted_req.status = RequestStatus.PREEMPTED
                     preempted_req.num_computed_tokens = 0
@@ -182,7 +182,7 @@
                     self.waiting.appendleft(preempted_req)
                     preempted_reqs.append(preempted_req)
 
-                    if preempted_req == request:
+                    if req_index == running_count:
                         # No more request to preempt.
                         can_schedule = False
                         break
@@ -208,7 +208,6 @@
                 ]
                 num_scheduled_tokens[request.request_id] = num_new_tokens
                 token_budget -= num_new_tokens
-                req_index += 1
 
             # Speculative decode related.
             if request.spec_token_ids:
@@ -230,6 +229,10 @@
                     self.encoder_cache_manager.allocate(request, i)
                 encoder_budget = new_encoder_budget
 
+        # Remove preempted requests from the running queue.
+        while len(self.running) > running_count:
+            self.running.popitem()
+
         # Record the LoRAs in scheduled_running_reqs
         requested_loras: set[int] = set()
         if self.lora_config:
@@ -255,9 +258,8 @@
                     if structured_output_req and structured_output_req.grammar:
                         request.status = RequestStatus.WAITING
                     else:
-                        waiting_structured_output_req = self.waiting.popleft()
-                        waiting_for_fsm.appendleft(
-                            waiting_structured_output_req)
+                        self.waiting.popleft()
+                        waiting_for_fsm.appendleft(request)
                         continue
 
                 # Check that adding the request still respects the max_loras
@@ -316,9 +318,8 @@
                 self.waiting.popleft()
                 if request.use_structured_output:
                     structured_output_request_ids[
-                        request.request_id] = req_index
-                req_index += 1
-                self.running.append(request)
+                        request.request_id] = len(self.running)
+                self.running[request] = None
                 self.scheduled_req_ids.add(request.request_id)
                 self.request_scheduled(request, scheduled_timestamp)
                 if request.status == RequestStatus.WAITING:
@@ -367,7 +368,7 @@
         # This can be potentially used for cascade attention.
         num_common_prefix_blocks = 0
         if self.running:
-            any_request = self.running[0]
+            any_request = next(iter(self.running))
             num_common_prefix_blocks = (
                 self.kv_cache_manager.get_num_common_prefix_blocks(
                     any_request, len(self.running)))
@@ -531,7 +532,7 @@
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
 
-        new_running: list[Request] = []
+        stopped_requests: list[Request] = []
         outputs: list[EngineCoreOutput] = []
 
         # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
@@ -542,7 +543,6 @@
             num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
             if num_tokens_scheduled == 0:
                 # The request was not scheduled in this step.
-                new_running.append(request)
                 continue
 
             req_index = model_runner_output.req_id_to_index[req_id]
@@ -601,6 +601,7 @@
                 stopped = self._check_stop(request)
                 if stopped:
                     self._free_request(request)
+                    stopped_requests.append(request)
                     break
 
                 # Extract sample logprobs if needed.
@@ -635,10 +636,10 @@
                         events=request.take_events()))
 
             self.scheduled_req_ids.remove(request.request_id)
-            if not stopped:
-                new_running.append(request)
 
-        self.running = new_running
+        for stopped_request in stopped_requests:
+            del self.running[stopped_request]
+
         return EngineCoreOutputs(
             outputs=outputs,
             scheduler_stats=self.make_stats(),
@@ -691,7 +692,7 @@
                 continue
 
             if request.status == RequestStatus.RUNNING:
-                self.running.remove(request)
+                del self.running[request]
                 self.scheduled_req_ids.discard(request.request_id)
             else:
                 self.waiting.remove(request)
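
For reference, a minimal self-contained sketch of the dict-as-ordered-set
pattern this change relies on. The Request class below is a simplified,
hypothetical stand-in for vllm.v1.request.Request (only hashability and
insertion order matter here), and the request ids are made up:

    class Request:
        def __init__(self, request_id: str):
            self.request_id = request_id

    running: dict[Request, None] = {}

    # "Append": plain dict insertion preserves insertion (FCFS) order.
    r1, r2, r3 = Request("r1"), Request("r2"), Request("r3")
    for r in (r1, r2, r3):
        running[r] = None

    # Iterate in FCFS order without copying the container.
    assert list(running) == [r1, r2, r3]

    # Peek at the oldest request, replacing self.running[0].
    assert next(iter(running)) is r1

    # Peek at the newest, lowest-priority request, replacing
    # self.running[-1].
    assert next(reversed(running)) is r3

    # O(1) removal by key (e.g. an aborted request), replacing the O(n)
    # self.running.remove(request).
    del running[r2]
    assert list(running) == [r1, r3]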
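
The subtlest part of the change is preemption: a dict cannot shrink while
it is being iterated, so schedule() only decrements running_count when it
preempts a request and physically drops the tail entries after the loop
via popitem(). A small simulation of that bookkeeping, with string request
ids standing in for Request objects:

    running = dict.fromkeys(["r1", "r2", "r3", "r4"])
    running_count = len(running)

    # Preempt twice, as if block allocation kept failing for "r1". The
    # lowest-priority victim that has not been preempted yet sits at
    # index running_count - 1; entries past it were preempted earlier in
    # this pass and are still physically present in the dict.
    for _ in range(2):
        victim = list(running)[running_count - 1]
        running_count -= 1
        print("preempted", victim)  # r4, then r3

    # After the scheduling loop: popitem() removes the most recently
    # inserted entry, i.e. exactly the preempted tail, in O(1) per entry.
    while len(running) > running_count:
        running.popitem()

    assert list(running) == ["r1", "r2"]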