[Core] LoRA: V1 Scheduler optimization (#15422)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Varun Sundar Rabindranath 2025-03-25 15:50:09 -07:00 committed by GitHub
parent ac3cd6e83c
commit a5cfbab3c8


@@ -239,16 +239,16 @@ class Scheduler(SchedulerInterface):
             encoder_budget = new_encoder_budget

         # Record the LoRAs in scheduled_running_reqs
-        requested_loras: set[int] = set()
+        scheduled_loras: set[int] = set()
         if self.lora_config:
-            requested_loras = set(
+            scheduled_loras = set(
                 req.lora_request.lora_int_id for req in scheduled_running_reqs
                 if req.lora_request and req.lora_request.lora_int_id > 0)
-            assert len(requested_loras) <= self.lora_config.max_loras
+            assert len(scheduled_loras) <= self.lora_config.max_loras

         # Use a temporary deque to collect requests that need to be skipped
         # and put back at the head of the waiting queue later
-        waiting_for_fsm: deque[Request] = deque()
+        skipped_waiting_requests: deque[Request] = deque()

         # Next, schedule the WAITING requests.
         if not preempted_reqs:
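
Note on this hunk: requested_loras becomes scheduled_loras, since the set
tracks the LoRAs of requests already scheduled in this step rather than
merely requested ones, and the FSM-specific waiting_for_fsm deque is
generalized into skipped_waiting_requests, a scratch queue for any waiting
request the scan decides to defer. Below is a minimal standalone sketch of
what the hunk sets up; FakeRequest and FakeLoRARequest are invented
stand-ins, not vLLM's real Request/LoRARequest types.

    from collections import deque
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeLoRARequest:  # illustrative stand-in, not vLLM's LoRARequest
        lora_int_id: int

    @dataclass
    class FakeRequest:  # illustrative stand-in, not vLLM's Request
        request_id: str
        lora_request: Optional[FakeLoRARequest] = None

    scheduled_running_reqs = [
        FakeRequest("a", FakeLoRARequest(1)),
        FakeRequest("b", FakeLoRARequest(1)),  # same LoRA, counted once
        FakeRequest("c"),                      # no LoRA, excluded
    ]

    # Mirrors the hunk: unique LoRA IDs among already-scheduled requests.
    scheduled_loras = {
        r.lora_request.lora_int_id
        for r in scheduled_running_reqs
        if r.lora_request and r.lora_request.lora_int_id > 0
    }
    assert scheduled_loras == {1}

    # Scratch deque; requests skipped by the waiting-queue scan go here.
    skipped_waiting_requests: deque = deque()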
@@ -258,31 +258,30 @@ class Scheduler(SchedulerInterface):
                 request = self.waiting[0]

-                if request.status == RequestStatus.WAITING_FOR_FSM:
+                # Waiting request skipping logic
+                is_skipped = False
+
+                # Skip request if the structured output request is still waiting
+                # for FSM.
+                if (not is_skipped
+                        and request.status == RequestStatus.WAITING_FOR_FSM):
                     structured_output_req = request.structured_output_request
-                    if structured_output_req and structured_output_req.grammar:
+                    is_skipped = (not structured_output_req
+                                  or not structured_output_req.grammar)
+                    if not is_skipped:
                         request.status = RequestStatus.WAITING
-                    else:
-                        waiting_structured_output_req = self.waiting.popleft()
-                        waiting_for_fsm.appendleft(
-                            waiting_structured_output_req)
-                        continue

-                # Check that adding the request still respects the max_loras
-                # constraint.
-                if self.lora_config and request.lora_request:
+                # Skip request if max_loras can't be honored.
+                if (not is_skipped and self.lora_config
+                        and request.lora_request):
                     req_lora_id = request.lora_request.lora_int_id
-                    if len(requested_loras) == self.lora_config.max_loras and (
-                            req_lora_id not in requested_loras):
-                        # Cannot schedule.
-                        # TODO (varun): This means all the other requests in
-                        # the WAITING queue will be blocked by this request,
-                        # even if,
-                        # 1. these other requests do not use LoRA, or,
-                        # 2. these other requests use the already requested
-                        # LoRAs.
-                        # This is too conservative and could be optimized.
-                        break
+                    is_skipped = (len(scheduled_loras)
+                                  == self.lora_config.max_loras
+                                  and (req_lora_id not in scheduled_loras))
+
+                if is_skipped:
+                    skipped_waiting_requests.appendleft(request)
+                    self.waiting.popleft()
+                    continue

                 # Get already-cached tokens.
                 computed_blocks, num_computed_tokens = \
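
Note on this hunk: it is the heart of the optimization and resolves the
removed TODO. Previously, a head-of-queue request whose LoRA could not be
honored hit break, stalling every request behind it, including requests
that use no LoRA at all or reuse an already-scheduled LoRA. Now such a
request is popped, parked in skipped_waiting_requests, and the scan moves
on. A hypothetical standalone simulation of the new behavior follows; the
(request_id, lora_id) tuples and max_loras = 1 are invented for
illustration.

    from collections import deque

    max_loras = 1
    scheduled_loras = {1}              # the only LoRA slot is taken by id 1
    waiting = deque([("r0", 2),        # needs LoRA 2 -> cannot be honored
                     ("r1", None),     # no LoRA -> schedulable
                     ("r2", 1)])       # reuses LoRA 1 -> schedulable

    skipped, scheduled = deque(), []
    while waiting:
        req_id, lora_id = waiting[0]
        is_skipped = (lora_id is not None
                      and len(scheduled_loras) == max_loras
                      and lora_id not in scheduled_loras)
        if is_skipped:
            skipped.appendleft(waiting.popleft())
            continue
        scheduled.append(waiting.popleft())
        if lora_id is not None:
            scheduled_loras.add(lora_id)

    waiting.extendleft(skipped)        # skipped requests return to the head
    assert [req_id for req_id, _ in scheduled] == ["r1", "r2"]
    assert list(waiting) == [("r0", 2)]

Under the old break behavior the same queue would schedule nothing, since
r0 blocks the head.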
@@ -344,7 +343,7 @@ class Scheduler(SchedulerInterface):
                     f"Invalid request status: {request.status}")

                 if self.lora_config and request.lora_request:
-                    requested_loras.add(request.lora_request.lora_int_id)
+                    scheduled_loras.add(request.lora_request.lora_int_id)
                 req_to_new_block_ids[request.request_id] = [
                     b.block_id for b in computed_blocks + new_blocks
                 ]
@@ -363,8 +362,8 @@ class Scheduler(SchedulerInterface):
             encoder_budget = new_encoder_budget

         # Put back any skipped requests at the head of the waiting queue
-        if waiting_for_fsm:
-            self.waiting.extendleft(waiting_for_fsm)
+        if skipped_waiting_requests:
+            self.waiting.extendleft(skipped_waiting_requests)

         # Check if the scheduling constraints are satisfied.
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
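
Note on the put-back: deque.extendleft inserts items one at a time and so
reverses its argument. Because skipped requests were collected with
appendleft, which already reversed them relative to scan order, the two
reversals cancel and the original head-of-queue order is restored. A quick
illustration with plain integers:

    from collections import deque

    waiting = deque([1, 2, 3, 4])
    skipped = deque()

    # Skip the first two items the way the scheduler does: popleft from
    # the waiting queue, appendleft onto the scratch deque.
    skipped.appendleft(waiting.popleft())   # skipped: [1]
    skipped.appendleft(waiting.popleft())   # skipped: [2, 1]

    # extendleft prepends item by item, reversing [2, 1] back to [1, 2].
    waiting.extendleft(skipped)
    assert list(waiting) == [1, 2, 3, 4]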