mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 17:45:23 +08:00
[Bugfix] V1 Fix the cursor leakage issue during request scheduling. (#21173)
Signed-off-by: CLFutureX <775523362@qq.com>
This commit is contained in:
parent
bdcb42e45d
commit
2dffac464c
@ -1307,13 +1307,18 @@ def create_requests_with_priority(
|
|||||||
mm_positions: Optional[list[PlaceholderRange]] = None,
|
mm_positions: Optional[list[PlaceholderRange]] = None,
|
||||||
max_tokens: int = 16,
|
max_tokens: int = 16,
|
||||||
stop_token_ids: Optional[list[int]] = None,
|
stop_token_ids: Optional[list[int]] = None,
|
||||||
prompt_logprobs: Optional[int] = None):
|
prompt_logprobs: Optional[int] = None,
|
||||||
|
request_ids: Optional[list[str]] = None):
|
||||||
"""Create requests with specified priorities and arrival times."""
|
"""Create requests with specified priorities and arrival times."""
|
||||||
assert len(priorities) == num_requests
|
assert len(priorities) == num_requests
|
||||||
if arrival_times is not None:
|
if arrival_times is not None:
|
||||||
assert len(arrival_times) == num_requests
|
assert len(arrival_times) == num_requests
|
||||||
else:
|
else:
|
||||||
arrival_times = [float(i) for i in range(num_requests)]
|
arrival_times = [float(i) for i in range(num_requests)]
|
||||||
|
if request_ids is not None:
|
||||||
|
assert len(request_ids) == num_requests
|
||||||
|
else:
|
||||||
|
request_ids = [f"{i}" for i in range(num_requests)]
|
||||||
|
|
||||||
sampling_params = SamplingParams(ignore_eos=False,
|
sampling_params = SamplingParams(ignore_eos=False,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
@ -1328,7 +1333,7 @@ def create_requests_with_priority(
|
|||||||
mm_position = None
|
mm_position = None
|
||||||
mm_inputs = None
|
mm_inputs = None
|
||||||
request = Request(
|
request = Request(
|
||||||
request_id=f"{i}",
|
request_id=request_ids[i],
|
||||||
prompt_token_ids=[i] * num_tokens,
|
prompt_token_ids=[i] * num_tokens,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
pooling_params=None,
|
pooling_params=None,
|
||||||
@ -1829,3 +1834,91 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
|
|||||||
assert len(output.scheduled_new_reqs) == 0
|
assert len(output.scheduled_new_reqs) == 0
|
||||||
assert len(scheduler.running) == 0
|
assert len(scheduler.running) == 0
|
||||||
assert len(scheduler.waiting) == 1
|
assert len(scheduler.waiting) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_priority_scheduling_preemption_victim_iterator_order():
|
||||||
|
"""Test that the scheduling order is maintained after
|
||||||
|
preempting lower-priority requests."""
|
||||||
|
scheduler = create_scheduler_with_priority(
|
||||||
|
max_num_batched_tokens=200,
|
||||||
|
num_blocks=9,
|
||||||
|
)
|
||||||
|
# Add three priority requests first.
|
||||||
|
priority_requests = create_requests_with_priority(
|
||||||
|
num_requests=3,
|
||||||
|
priorities=[3, 4, 5],
|
||||||
|
arrival_times=[1.0, 2.0, 3.0],
|
||||||
|
num_tokens=15,
|
||||||
|
request_ids=["1", "2", "3"],
|
||||||
|
)
|
||||||
|
|
||||||
|
for request in priority_requests:
|
||||||
|
scheduler.add_request(request)
|
||||||
|
# After scheduling, transfer from the waiting queue to the running queue.
|
||||||
|
# At this time, 3 blocks have been allocated, and 5 available blocks remain.
|
||||||
|
output = scheduler.schedule()
|
||||||
|
|
||||||
|
model_output = ModelRunnerOutput(
|
||||||
|
req_ids=[req.request_id for req in priority_requests],
|
||||||
|
req_id_to_index={
|
||||||
|
req.request_id: i
|
||||||
|
for i, req in enumerate(priority_requests)
|
||||||
|
},
|
||||||
|
sampled_token_ids=[[15] for _ in priority_requests],
|
||||||
|
spec_token_ids=None,
|
||||||
|
logprobs=None,
|
||||||
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
|
)
|
||||||
|
scheduler.update_from_output(output, model_output)
|
||||||
|
|
||||||
|
# Add tow high priority requests.
|
||||||
|
high_priority_requests = create_requests_with_priority(
|
||||||
|
num_requests=2,
|
||||||
|
priorities=[1, 2],
|
||||||
|
arrival_times=[4.0, 5.0],
|
||||||
|
num_tokens=16,
|
||||||
|
request_ids=["4", "5"],
|
||||||
|
)
|
||||||
|
for request in high_priority_requests:
|
||||||
|
scheduler.add_request(request)
|
||||||
|
|
||||||
|
# After scheduling, transfer the two high-priority requests from
|
||||||
|
# the waiting queue to the running queue.
|
||||||
|
# the IDs of the requests in the running queue are: 1, 2, 3, 4, 5.
|
||||||
|
# At this time, 3+2 blocks have been allocated,
|
||||||
|
# and 3 available blocks remain.
|
||||||
|
output = scheduler.schedule()
|
||||||
|
|
||||||
|
merge_requests = priority_requests + high_priority_requests
|
||||||
|
|
||||||
|
model_output = ModelRunnerOutput(
|
||||||
|
req_ids=[req.request_id for req in merge_requests],
|
||||||
|
req_id_to_index={
|
||||||
|
req.request_id: i
|
||||||
|
for i, req in enumerate(merge_requests)
|
||||||
|
},
|
||||||
|
sampled_token_ids=[[1] for _ in merge_requests],
|
||||||
|
spec_token_ids=None,
|
||||||
|
logprobs=None,
|
||||||
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=[],
|
||||||
|
)
|
||||||
|
scheduler.update_from_output(output, model_output)
|
||||||
|
|
||||||
|
# At this time, the request with the lowest priority
|
||||||
|
# (request.id = 2) will be preempted, freeing up 2 blocks,
|
||||||
|
# which exactly meets the resource allocation requirements
|
||||||
|
# for request.id = 4 and request.id = 5.
|
||||||
|
output = scheduler.schedule()
|
||||||
|
|
||||||
|
# Should schedule the new request without preemption.
|
||||||
|
assert len(scheduler.running) == 4 #
|
||||||
|
assert len(scheduler.waiting) == 1 #
|
||||||
|
|
||||||
|
running_priorities = [req.priority for req in scheduler.running]
|
||||||
|
running_req_ids = [req.request_id for req in scheduler.running]
|
||||||
|
|
||||||
|
assert running_priorities == [3, 4, 1, 2]
|
||||||
|
assert running_req_ids == ["1", "2", "4", "5"]
|
||||||
|
assert scheduler.waiting.peek_request().priority == 5
|
||||||
|
|||||||
@ -257,7 +257,11 @@ class Scheduler(SchedulerInterface):
|
|||||||
self.running,
|
self.running,
|
||||||
key=lambda r: (r.priority, r.arrival_time),
|
key=lambda r: (r.priority, r.arrival_time),
|
||||||
)
|
)
|
||||||
self.running.remove(preempted_req)
|
preempted_index = self.running.index(preempted_req)
|
||||||
|
if preempted_index <= req_index:
|
||||||
|
req_index -= 1
|
||||||
|
scheduled_running_reqs.remove(preempted_req)
|
||||||
|
self.running.pop(preempted_index)
|
||||||
else:
|
else:
|
||||||
preempted_req = self.running.pop()
|
preempted_req = self.running.pop()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user