[Core][KVConnector] Propagate all tokens on resumed preemptions (#24926)

Signed-off-by: Qier Li <kevin44036@gmail.com>
Co-authored-by: Qier Li <qier@fb.com>

parent 43ab8cfaa5
commit d17f0fbf30
@@ -1950,7 +1950,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
     assert len(scheduler.waiting) == 1


-def test_priority_scheduling_preemption_when_out_of_kv():
+def test_priority_scheduling_preemption_and_resumption_when_out_of_kv():
     """Test that priority scheduling preempts lower priority requests
     when out of KV cache space."""
     # Create scheduler with very limited memory to force preemption
@@ -1959,6 +1959,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
         max_num_batched_tokens=200,
         num_blocks=5,  # Can hold 64 tokens (first block is null)
         block_size=16,  # Standard block size
+        use_kv_connector=True,
     )

     # Create a request and schedule it
@@ -1970,12 +1971,13 @@ def test_priority_scheduling_preemption_when_out_of_kv():
         starting_idx=0,
     )[0]
     scheduler.add_request(request_low)
+    # 1st schedule
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == 1
     assert len(scheduler.waiting) == 0
     assert len(scheduler.running) == 1

-    # Simulate model execution
+    # Simulate model execution - 1st decode
     model_output = ModelRunnerOutput(
         req_ids=[request_low.request_id],
         req_id_to_index={request_low.request_id: 0},
@@ -1996,6 +1998,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
         starting_idx=1,
     )[0]
     scheduler.add_request(request_high)
+    # 2nd schedule
     output = scheduler.schedule()
     # KV cache should be full at this point
     assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == 0
@@ -2004,7 +2007,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
     assert len(scheduler.waiting) == 0
     assert len(scheduler.running) == 2

-    # Simulate model execution
+    # Simulate model execution - 2nd decode
     requests = [request_low, request_high]
     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
@@ -2017,7 +2020,7 @@ def test_priority_scheduling_preemption_when_out_of_kv():
     )
     scheduler.update_from_output(output, model_output)

-    # Schedule again - this should trigger preemption
+    # 3rd schedule - this should trigger preemption
     # req_low needs 32 tokens = 2 blocks
     # req_high needs 33 tokens = 3 blocks
     # so doesn't fit in 4 blocks.
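The preemption condition in the comments above follows directly from the scheduler setup in this test (num_blocks=5 with a null first block, block_size=16). A minimal sketch of that arithmetic, using only the numbers quoted in the diff:

from math import ceil

block_size = 16
usable_blocks = 5 - 1                 # num_blocks=5, but the first block is the null block
blocks_low = ceil(32 / block_size)    # req_low holds 32 tokens -> 2 blocks
blocks_high = ceil(33 / block_size)   # req_high needs 33 tokens -> 3 blocks
# 2 + 3 = 5 blocks needed, but only 4 are usable, so the lower-priority request is preempted.
assert blocks_low + blocks_high > usable_blocks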
@@ -2027,9 +2030,44 @@ def test_priority_scheduling_preemption_when_out_of_kv():
     assert len(output.scheduled_new_reqs) == 0
     assert output.scheduled_cached_reqs.num_reqs == 1
     assert output.scheduled_cached_reqs.req_ids[0] == request_high.request_id
     assert scheduler.requests[request_low.request_id].status == RequestStatus.PREEMPTED
     assert len(scheduler.waiting) == 1
     assert len(scheduler.running) == 1

+    # Simulate model execution - 3rd decode
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in requests],
+        req_id_to_index={req.request_id: i for i, req in enumerate(requests)},
+        sampled_token_ids=[[], [100]],
+        # spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    # Finish the requests to make room for the preempted requests to resume
+    scheduler.update_from_output(output, model_output)
+    scheduler.finish_requests(request_high.request_id, RequestStatus.FINISHED_STOPPED)
+
+    # 4th schedule - this should trigger the resumption
+    output = scheduler.schedule()
+    scheduled_cached_reqs = output.scheduled_cached_reqs
+    resumed_from_preemption = scheduled_cached_reqs.resumed_from_preemption
+
+    assert len(output.scheduled_new_reqs) == 0
+    assert scheduled_cached_reqs.num_reqs == 1
+    assert len(scheduler.waiting) == 0
+    assert len(scheduler.running) == 1
+
+    # Preempted request resumed in scheduled_cached_reqs
+    assert len(resumed_from_preemption) == 1
+    assert len(scheduled_cached_reqs.resumed_req_token_ids) == 1
+    assert resumed_from_preemption[0]
+    assert scheduled_cached_reqs.req_ids[0] == request_low.request_id
+    assert scheduled_cached_reqs.resumed_req_token_ids[0] is not None
+    # Resumed tokens include 30 prompt tokens and 2 decoded tokens
+    assert len(scheduled_cached_reqs.resumed_req_token_ids[0]) == 32
+    assert scheduled_cached_reqs.resumed_req_token_ids[0][31] == 100


 @pytest.mark.parametrize(
     ("enable_chunked_prefill", "is_encoder_decoder", "expect_enabled"),
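The final two assertions spell out what "propagate all tokens" means for the resumed request: its entire history, prompt plus everything decoded before preemption. A small standalone sketch with placeholder token ids; only the count of 30 prompt tokens, the decoded id 100, and the total of 32 come from the test:

resumed = list(range(30)) + [7, 100]   # placeholder prompt ids + two decoded tokens
assert len(resumed) == 32              # 30 prompt + 2 decoded tokens
assert resumed[31] == 100              # last decoded token, as asserted in the test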
@@ -257,6 +257,7 @@ def test_update_states_request_resumed(model_runner, dist_init):
         req_ids=[req_id],
         resumed_from_preemption=[False],
         new_token_ids=[[]],
+        resumed_req_token_ids=[None],
         new_block_ids=([[0]],),
         num_computed_tokens=[0],
         num_output_tokens=[0],
@@ -98,6 +98,9 @@ class CachedRequestData:
     # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
     # When PP is not used, new_token_ids will be empty.
     new_token_ids: list[list[int]]
+    # If resumed_from_preemption is True, propagate the token ids to the
+    # connector; otherwise this will be empty.
+    resumed_req_token_ids: list[list[int] | None]
     new_block_ids: list[tuple[list[int], ...] | None]
     num_computed_tokens: list[int]
     num_output_tokens: list[int]
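To make the new field's contract concrete: entries are full token lists only at positions where resumed_from_preemption is True, and None elsewhere. The consumer below is a hypothetical sketch; the function name and iteration pattern are illustrative, not vLLM's actual connector API:

def collect_resumed_tokens(cached_reqs):
    # cached_reqs is a CachedRequestData; returns {req_id: full token history}
    # for requests that are being resumed after preemption.
    resumed = {}
    for i, req_id in enumerate(cached_reqs.req_ids):
        if cached_reqs.resumed_from_preemption[i]:
            tokens = cached_reqs.resumed_req_token_ids[i]
            assert tokens is not None       # populated for resumed requests
            resumed[req_id] = tokens
        else:
            assert cached_reqs.resumed_req_token_ids[i] is None
    return resumed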
@@ -112,6 +115,7 @@ class CachedRequestData:
             req_ids=[],
             resumed_from_preemption=[],
             new_token_ids=[],
+            resumed_req_token_ids=[],
             new_block_ids=[],
             num_computed_tokens=[],
             num_output_tokens=[],
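Every per-request field in CachedRequestData is a parallel list indexed the same way, so the new field also has to appear in make_empty() and at every construction site (as the test change above does). A quick sanity check, assuming the import path vllm.v1.core.sched.output and that these empty kwargs live in CachedRequestData.make_empty():

from vllm.v1.core.sched.output import CachedRequestData

empty = CachedRequestData.make_empty()
assert empty.resumed_req_token_ids == []
assert len(empty.req_ids) == len(empty.resumed_from_preemption) == 0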
@@ -709,10 +709,15 @@ class Scheduler(SchedulerInterface):
         req_ids: list[str] = []
         new_token_ids: list[list[int]] = []
         new_block_ids: list[tuple[list[int], ...] | None] = []
+        resumed_req_token_ids: list[list[int] | None] = []
         num_computed_tokens: list[int] = []
         num_output_tokens: list[int] = []

-        for req in itertools.chain(running_reqs, resumed_reqs):
+        # Because resumed_reqs is usually empty, it is more efficient to do
+        # in-place appending so that we don't need to allocate a new list.
+        resumed_from_preemption = [False] * len(running_reqs)
+        resumed_from_preemption += [True] * len(resumed_reqs)
+        for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
             req_id = req.request_id
             req_ids.append(req_id)
             num_tokens = num_scheduled_tokens[req_id] - len(
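Moving the resumed_from_preemption construction above the loop lets the loop body index it with idx, relying on the flags lining up with itertools.chain(running_reqs, resumed_reqs): running requests first, then resumed ones. A tiny standalone illustration with placeholder request ids:

import itertools

running_reqs = ["req-a", "req-b"]   # placeholder running requests
resumed_reqs = ["req-c"]            # placeholder request resumed from preemption
resumed_from_preemption = [False] * len(running_reqs)
resumed_from_preemption += [True] * len(resumed_reqs)
for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
    assert resumed_from_preemption[idx] == (req in resumed_reqs)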
@@ -728,20 +733,23 @@ class Scheduler(SchedulerInterface):
                 req.num_computed_tokens : req.num_computed_tokens + num_tokens
             ]
             new_token_ids.append(token_ids)
+            resumed_token_ids = None
+            if resumed_from_preemption[idx]:
+                resumed_token_ids = req.all_token_ids[
+                    : req.num_computed_tokens + num_tokens
+                ]
+            resumed_req_token_ids.append(resumed_token_ids)
             new_block_ids.append(
                 req_to_new_blocks[req_id].get_block_ids(allow_none=True)
             )
             num_computed_tokens.append(req.num_computed_tokens)
             num_output_tokens.append(req.num_output_tokens)
-        # Because resumed_reqs is usually empty, it is more efficient to do
-        # in-place appending so that we don't need to allocate a new list.
-        resumed_from_preemption = [False] * len(running_reqs)
-        resumed_from_preemption += [True] * len(resumed_reqs)

         return CachedRequestData(
             req_ids=req_ids,
             resumed_from_preemption=resumed_from_preemption,
             new_token_ids=new_token_ids,
+            resumed_req_token_ids=resumed_req_token_ids,
             new_block_ids=new_block_ids,
             num_computed_tokens=num_computed_tokens,
             num_output_tokens=num_output_tokens,
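The slice req.all_token_ids[: req.num_computed_tokens + num_tokens] is what actually propagates all tokens: for a resumed request it covers the prompt plus every token generated before preemption, up through what is being scheduled now. A minimal sketch with made-up numbers, assuming the preempted request restarts with zero computed tokens:

all_token_ids = list(range(30)) + [7, 100]   # placeholder: 30 prompt + 2 decoded tokens
num_computed_tokens = 0                      # KV blocks were freed on preemption
num_tokens = 32                              # tokens scheduled in this step
resumed_token_ids = all_token_ids[: num_computed_tokens + num_tokens]
assert len(resumed_token_ids) == 32          # the full history travels to the connector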