[Optimization] Use Shared CachedRequestData Instance Across All Requests (#20232)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Authored by Woosuk Kwon on 2025-06-30 09:07:50 -07:00; committed by GitHub
parent 2965c99c86
commit 2863befce3
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
12 changed files with 220 additions and 231 deletions
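
In short: the scheduler previously built (and pooled) one CachedRequestData object per running/resumed request at every scheduling step; after this change it emits a single CachedRequestData instance per step whose fields are parallel lists indexed by request position, and the per-request object pool (_cached_reqs_data) is removed. Consumers switch from iterating objects to iterating indices. A minimal sketch of the consumer-side pattern, where update_request_state is a hypothetical helper standing in for whatever a caller does with each entry:

    # Before: one CachedRequestData object per previously scheduled request.
    for cached_req in scheduler_output.scheduled_cached_reqs:
        update_request_state(cached_req.req_id,  # hypothetical helper
                             cached_req.num_computed_tokens,
                             cached_req.new_token_ids,
                             cached_req.new_block_ids,
                             cached_req.resumed_from_preemption)

    # After: one shared CachedRequestData instance with parallel lists.
    cached_reqs = scheduler_output.scheduled_cached_reqs
    for i, req_id in enumerate(cached_reqs.req_ids):
        update_request_state(req_id,  # hypothetical helper
                             cached_reqs.num_computed_tokens[i],
                             cached_reqs.new_token_ids[i],
                             cached_reqs.new_block_ids[i],
                             cached_reqs.resumed_from_preemption[i])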

View File

@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
-from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                          KVCacheGroupSpec)
@@ -198,7 +198,7 @@ def test_schedule(enable_prefix_caching: Optional[bool],
     # Test initial scheduling
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == len(requests)
-    assert len(output.scheduled_cached_reqs) == 0
+    assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0
     # Verify all requests are scheduled.
     for req_id, num_tokens in output.num_scheduled_tokens.items():
@@ -225,7 +225,7 @@ def test_schedule_multimodal_requests():
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == len(requests)
-    assert len(output.scheduled_cached_reqs) == 0
+    assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0
     for req_id, num_tokens in output.num_scheduled_tokens.items():
         assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
@@ -259,7 +259,7 @@ def test_schedule_partial_requests():
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == 3
-    assert len(output.scheduled_cached_reqs) == 0
+    assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0
     assert scheduler.max_num_encoder_input_tokens == 1024
@@ -295,7 +295,7 @@ def test_schedule_partial_requests():
     output = scheduler.schedule()
     assert len(scheduler.running) == 3
     assert len(output.scheduled_new_reqs) == 0
-    assert len(output.scheduled_cached_reqs) == 2
+    assert output.scheduled_cached_reqs.num_reqs == 2
     assert len(output.finished_req_ids) == 0
     assert output.num_scheduled_tokens[requests[0].request_id] == 1
     assert output.num_scheduled_tokens[requests[1].request_id] == 700
@@ -319,7 +319,7 @@ def test_no_mm_input_chunking():
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == 1
-    assert len(output.scheduled_cached_reqs) == 0
+    assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0
     # We want to only see the 400 text tokens at the start scheduled
     assert output.num_scheduled_tokens[requests[0].request_id] == 400
@@ -342,7 +342,7 @@ def test_no_mm_input_chunking():
     output = scheduler.schedule()
     assert len(scheduler.running) == 1
     assert len(output.scheduled_new_reqs) == 0
-    assert len(output.scheduled_cached_reqs) == 1
+    assert output.scheduled_cached_reqs.num_reqs == 1
     assert len(output.finished_req_ids) == 0
     assert output.num_scheduled_tokens[requests[0].request_id] == 800
@@ -379,7 +379,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == 3
-    assert len(output.scheduled_cached_reqs) == 0
+    assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0

     # The first request is scheduled partially - 400.
@@ -408,7 +408,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     output1 = scheduler.schedule()
     assert len(scheduler.running) == 3
     assert len(output1.scheduled_new_reqs) == 0
-    assert len(output1.scheduled_cached_reqs) == 3
+    assert output1.scheduled_cached_reqs.num_reqs == 3
     assert len(output1.finished_req_ids) == 0
     assert output1.num_scheduled_tokens[requests[0].request_id] == 400
     assert output1.num_scheduled_tokens[requests[1].request_id] == 400
@@ -430,7 +430,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     output2 = scheduler.schedule()
     assert len(scheduler.running) == 3
     assert len(output2.scheduled_new_reqs) == 0
-    assert len(output2.scheduled_cached_reqs) == 3
+    assert output2.scheduled_cached_reqs.num_reqs == 3
     assert len(output2.finished_req_ids) == 0
     assert output2.num_scheduled_tokens[requests[0].request_id] == 1
     assert output2.num_scheduled_tokens[requests[1].request_id] == 1
@@ -449,23 +449,24 @@ def test_stop_via_update_from_output():
         scheduler.requests[req.request_id] = req
         scheduler.running.append(req)

-    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
-                                       scheduled_cached_reqs=[],
-                                       num_scheduled_tokens={
-                                           requests[0].request_id: 1,
-                                           requests[1].request_id: 2
-                                       },
-                                       total_num_scheduled_tokens=3,
-                                       scheduled_encoder_inputs={},
-                                       scheduled_spec_decode_tokens={
-                                           requests[0].request_id: [],
-                                           requests[1].request_id: [10]
-                                       },
-                                       num_common_prefix_blocks=0,
-                                       finished_req_ids=set(),
-                                       free_encoder_input_ids=[],
-                                       structured_output_request_ids={},
-                                       grammar_bitmask=None)
+    scheduler_output = SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
+        num_scheduled_tokens={
+            requests[0].request_id: 1,
+            requests[1].request_id: 2
+        },
+        total_num_scheduled_tokens=3,
+        scheduled_encoder_inputs={},
+        scheduled_spec_decode_tokens={
+            requests[0].request_id: [],
+            requests[1].request_id: [10]
+        },
+        num_common_prefix_blocks=0,
+        finished_req_ids=set(),
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None)

     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
@@ -501,23 +502,25 @@ def test_stop_via_update_from_output():
         scheduler.requests[req.request_id] = req
         scheduler.running.append(req)

-    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
-                                       scheduled_cached_reqs=[],
-                                       num_scheduled_tokens={
-                                           requests[0].request_id: 3,
-                                           requests[1].request_id: 2
-                                       },
-                                       total_num_scheduled_tokens=5,
-                                       scheduled_encoder_inputs={},
-                                       scheduled_spec_decode_tokens={
-                                           requests[0].request_id: [10, 42],
-                                           requests[1].request_id: [13]
-                                       },
-                                       num_common_prefix_blocks=0,
-                                       finished_req_ids=set(),
-                                       free_encoder_input_ids=[],
-                                       structured_output_request_ids={},
-                                       grammar_bitmask=None)
+    scheduler_output = SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
+        num_scheduled_tokens={
+            requests[0].request_id: 3,
+            requests[1].request_id: 2
+        },
+        total_num_scheduled_tokens=5,
+        scheduled_encoder_inputs={},
+        scheduled_spec_decode_tokens={
+            requests[0].request_id: [10, 42],
+            requests[1].request_id: [13]
+        },
+        num_common_prefix_blocks=0,
+        finished_req_ids=set(),
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
+    )

     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
@@ -551,23 +554,25 @@ def test_stop_via_update_from_output():
         scheduler.requests[req.request_id] = req
         scheduler.running.append(req)

-    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
-                                       scheduled_cached_reqs=[],
-                                       num_scheduled_tokens={
-                                           requests[0].request_id: 3,
-                                           requests[1].request_id: 1
-                                       },
-                                       total_num_scheduled_tokens=4,
-                                       scheduled_encoder_inputs={},
-                                       scheduled_spec_decode_tokens={
-                                           requests[0].request_id: [10, 11],
-                                           requests[1].request_id: []
-                                       },
-                                       num_common_prefix_blocks=0,
-                                       finished_req_ids=set(),
-                                       free_encoder_input_ids=[],
-                                       structured_output_request_ids={},
-                                       grammar_bitmask=None)
+    scheduler_output = SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
+        num_scheduled_tokens={
+            requests[0].request_id: 3,
+            requests[1].request_id: 1
+        },
+        total_num_scheduled_tokens=4,
+        scheduled_encoder_inputs={},
+        scheduled_spec_decode_tokens={
+            requests[0].request_id: [10, 11],
+            requests[1].request_id: []
+        },
+        num_common_prefix_blocks=0,
+        finished_req_ids=set(),
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
+    )

     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
@@ -603,7 +608,7 @@ def test_stop_via_update_from_output():
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={requests[0].request_id: 3},
         total_num_scheduled_tokens=3,
         scheduled_encoder_inputs={},
@@ -1208,7 +1213,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
     assert len(scheduler.waiting) == 0
     assert len(scheduler.running) == 0
     assert len(scheduler.finished_req_ids) == 0
-    assert len(scheduler._cached_reqs_data) == 0

     # EncoderCacheManager.
     assert len(scheduler.encoder_cache_manager.freed) == 0

View File

@@ -66,7 +66,7 @@ def test_basic_lifecycle():
     assert len(scheduler_output.finished_req_ids) == 1
     assert request_id in scheduler_output.finished_req_ids
     assert len(scheduler_output.scheduled_new_reqs) == 0
-    assert len(scheduler_output.scheduled_cached_reqs) == 0
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
     assert len(scheduler.finished_req_ids) == 0

     # (2b): execute_model()
@@ -81,7 +81,7 @@ def test_basic_lifecycle():
     assert len(scheduler.running) == 0
     assert len(scheduler_output.finished_req_ids) == 0
     assert len(scheduler_output.scheduled_new_reqs) == 0
-    assert len(scheduler_output.scheduled_cached_reqs) == 0
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
     assert len(scheduler.finished_req_ids) == 0

     # (3b): execute_model()

View File

@@ -36,7 +36,7 @@ def test_basic_lifecycle():
     # Nothing running and empty scheduler output.
     assert len(scheduler.running) == 0
     assert len(scheduler_output.scheduled_new_reqs) == 0
-    assert len(scheduler_output.scheduled_cached_reqs) == 0
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
     assert len(scheduler_output.num_scheduled_tokens) == 0
     assert scheduler_output.total_num_scheduled_tokens == 0
@@ -158,7 +158,7 @@ def test_interleaved_lifecycle():
     assert len(scheduler.running) == 2
     assert len(scheduler.waiting) == 1
     assert len(scheduler_output.scheduled_new_reqs) == 1
-    assert len(scheduler_output.scheduled_cached_reqs) == 1
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 1

     model_runner_output = create_model_runner_output(
         [request_local_a, request_local_b])
@@ -169,7 +169,7 @@ def test_interleaved_lifecycle():
     assert len(scheduler.running) == 2
     assert len(scheduler.waiting) == 1
     assert len(scheduler_output.scheduled_new_reqs) == 0
-    assert len(scheduler_output.scheduled_cached_reqs) == 2
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2

     model_runner_output = create_model_runner_output(
         reqs=[request_local_a, request_local_b])
@@ -177,14 +177,14 @@ def test_interleaved_lifecycle():
     assert len(scheduler.running) == 2
     assert len(scheduler.waiting) == 1
     assert len(scheduler_output.scheduled_new_reqs) == 0
-    assert len(scheduler_output.scheduled_cached_reqs) == 2
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2

     # STEP 4: KVs arrive.
     scheduler_output = scheduler.schedule()
     assert len(scheduler.running) == 2
     assert len(scheduler.waiting) == 1
     assert len(scheduler_output.scheduled_new_reqs) == 0
-    assert len(scheduler_output.scheduled_cached_reqs) == 2
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2

     model_runner_output = create_model_runner_output(
         [request_local_a, request_local_b],
@@ -196,7 +196,7 @@ def test_interleaved_lifecycle():
     assert len(scheduler.running) == 3
     assert len(scheduler.waiting) == 0
     assert len(scheduler_output.scheduled_new_reqs) == 1
-    assert len(scheduler_output.scheduled_cached_reqs) == 2
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2

     model_runner_output = create_model_runner_output(
         [request_local_a, request_local_b, request_remote])

View File

@@ -25,7 +25,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
     assert len(scheduler.running) == 0
     assert len(scheduler.finished_req_ids) == 0
     assert len(scheduler.finished_recving_kv_req_ids) == 0
-    assert len(scheduler._cached_reqs_data) == 0

     # EncoderCacheManager.
     assert len(scheduler.encoder_cache_manager.freed) == 0

View File

@@ -82,7 +82,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
     return SchedulerOutput(
         scheduled_new_reqs=new_reqs,
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens=num_scheduled_tokens,
         total_num_scheduled_tokens=total_num_scheduled_tokens,
         scheduled_spec_decode_tokens={},
@@ -161,7 +161,7 @@ def test_update_states_request_finished(model_runner):
     # finish req
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={},
         total_num_scheduled_tokens=0,
         scheduled_spec_decode_tokens={},
@@ -191,7 +191,7 @@ def test_update_states_request_resumed(model_runner):
     # unschedule req
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={},
         total_num_scheduled_tokens=0,
         scheduled_spec_decode_tokens={},
@@ -209,16 +209,16 @@ def test_update_states_request_resumed(model_runner):
     # resume req
     cached_req_data = CachedRequestData(
-        req_id=req_id,
-        resumed_from_preemption=False,
-        new_token_ids=[],
-        new_block_ids=([], ),
-        num_computed_tokens=0,
+        req_ids=[req_id],
+        resumed_from_preemption=[False],
+        new_token_ids=[[]],
+        new_block_ids=[([], )],
+        num_computed_tokens=[0],
     )

     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[cached_req_data],
+        scheduled_cached_reqs=cached_req_data,
         num_scheduled_tokens={req_id: 1},
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
@@ -249,7 +249,7 @@ def test_update_states_no_changes(model_runner):
     # schedule req
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={req_id: 1},
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
@@ -284,7 +284,7 @@ def test_update_states_request_unscheduled(model_runner):
     # unschedule req_1
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={req_ids[0]: 1},
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},

View File

@@ -133,7 +133,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
     return SchedulerOutput(
         scheduled_new_reqs=new_reqs,
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens=num_scheduled_tokens,
         total_num_scheduled_tokens=total_num_scheduled_tokens,
         scheduled_spec_decode_tokens={},
@@ -199,7 +199,7 @@ def test_update_states_request_finished(model_runner):
     # finish req
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={},
         total_num_scheduled_tokens=0,
         scheduled_spec_decode_tokens={},
@@ -231,7 +231,7 @@ def test_update_states_request_resumed(model_runner):
     # unschedule req
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={},
         total_num_scheduled_tokens=0,
         scheduled_spec_decode_tokens={},
@@ -249,16 +249,16 @@ def test_update_states_request_resumed(model_runner):
     # resume req
     cached_req_data = CachedRequestData(
-        req_id=req_id,
-        resumed_from_preemption=False,
-        new_token_ids=[],
-        new_block_ids=([], ),
-        num_computed_tokens=0,
+        req_ids=[req_id],
+        resumed_from_preemption=[False],
+        new_token_ids=[[]],
+        new_block_ids=([[0]], ),
+        num_computed_tokens=[0],
     )

     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[cached_req_data],
+        scheduled_cached_reqs=cached_req_data,
         num_scheduled_tokens={req_id: 1},
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
@@ -339,7 +339,7 @@ def test_update_states_no_changes(model_runner):
     # schedule req
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={req_id: 1},
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
@@ -376,7 +376,7 @@ def test_update_states_request_unscheduled(model_runner):
     # unschedule req_1
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={req_ids[0]: 1},
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},

View File

@@ -371,45 +371,48 @@ class P2pNcclConnector(KVConnectorBase_V1):
                                  block_size=self._block_size)
                 self._requests_need_load.pop(new_req.req_id)

-        for cached_req in scheduler_output.scheduled_cached_reqs:
+        cached_reqs = scheduler_output.scheduled_cached_reqs
+        for i, req_id in enumerate(cached_reqs.req_ids):
+            num_computed_tokens = cached_reqs.num_computed_tokens[i]
+            new_block_ids = cached_reqs.new_block_ids[i]
+            resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
             if self.is_producer:
                 num_scheduled_tokens = (
-                    scheduler_output.num_scheduled_tokens)[cached_req.req_id]
-                num_tokens = (num_scheduled_tokens +
-                              cached_req.num_computed_tokens)
-                assert cached_req.req_id in self.chunked_prefill
-                block_ids = cached_req.new_block_ids[0]
-                if not cached_req.resumed_from_preemption:
-                    block_ids = (self.chunked_prefill[cached_req.req_id][0] +
-                                 block_ids)
-                prompt_token_ids = self.chunked_prefill[cached_req.req_id][1]
+                    scheduler_output.num_scheduled_tokens)[req_id]
+                num_tokens = (num_scheduled_tokens + num_computed_tokens)
+                assert req_id in self.chunked_prefill
+                block_ids = new_block_ids[0]
+                if not resumed_from_preemption:
+                    block_ids = (self.chunked_prefill[req_id][0] + block_ids)
+                prompt_token_ids = self.chunked_prefill[req_id][1]
                 # the request's prompt is chunked prefill again
                 if num_tokens < len(prompt_token_ids):
-                    self.chunked_prefill[cached_req.req_id] = (
-                        block_ids, prompt_token_ids)
+                    self.chunked_prefill[req_id] = (block_ids,
+                                                    prompt_token_ids)
                     continue
                 # the request's prompt is all prefilled finally
-                meta.add_request(request_id=cached_req.req_id,
+                meta.add_request(request_id=req_id,
                                  token_ids=prompt_token_ids,
                                  block_ids=block_ids,
                                  block_size=self._block_size)
-                self.chunked_prefill.pop(cached_req.req_id, None)
+                self.chunked_prefill.pop(req_id, None)
                 continue

             # NOTE(rob): here we rely on the resumed requests being
             # the first N requests in the list scheduled_cache_reqs.
-            if not cached_req.resumed_from_preemption:
+            if not resumed_from_preemption:
                 break
-            if cached_req.req_id in self._requests_need_load:
-                request, _ = self._requests_need_load.pop(cached_req.req_id)
-                total_tokens = cached_req.num_computed_tokens + 1
+            if req_id in self._requests_need_load:
+                request, _ = self._requests_need_load.pop(req_id)
+                total_tokens = num_computed_tokens + 1
                 token_ids = request.all_token_ids[:total_tokens]
                 # NOTE(rob): For resumed req, new_block_ids is all
                 # of the block_ids for the request.
-                block_ids = cached_req.new_block_ids[0]
-                meta.add_request(request_id=cached_req.req_id,
+                block_ids = new_block_ids[0]
+                meta.add_request(request_id=req_id,
                                  token_ids=token_ids,
                                  block_ids=block_ids,
                                  block_size=self._block_size)

View File

@@ -304,23 +304,28 @@ class SharedStorageConnector(KVConnectorBase_V1):
                                  block_size=self._block_size,
                                  is_store=True)

-        for cached_req in scheduler_output.scheduled_cached_reqs:
+        cached_reqs = scheduler_output.scheduled_cached_reqs
+        for i, req_id in enumerate(cached_reqs.req_ids):
+            num_computed_tokens = cached_reqs.num_computed_tokens[i]
+            new_token_ids = cached_reqs.new_token_ids[i]
+            new_block_ids = cached_reqs.new_block_ids[i]
+            resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
             # NOTE(rob): here we rely on the resumed requests being
             # the first N requests in the list scheduled_cache_reqs.
-            if not cached_req.resumed_from_preemption:
+            if not resumed_from_preemption:
                 break
-            if cached_req.req_id in self._requests_need_load:
+            if req_id in self._requests_need_load:
                 # NOTE(rob): cached_req_data does not have the full
                 # list of token ids (only new tokens). So we look it
                 # up in the actual request object.
-                request = self._requests_need_load[cached_req.req_id]
-                total_tokens = (len(cached_req.new_token_ids) +
-                                cached_req.num_computed_tokens)
+                request = self._requests_need_load[req_id]
+                total_tokens = (len(new_token_ids) + num_computed_tokens)
                 token_ids = request.all_token_ids[:total_tokens]
                 # NOTE(rob): For resumed req, new_block_ids is all
                 # of the block_ids for the request.
-                block_ids = cached_req.new_block_ids[0]
+                block_ids = new_block_ids[0]
                 meta.add_request(token_ids=token_ids,
                                  block_ids=block_ids,
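
As the NOTE above says, CachedRequestData only carries the per-step diff, so a consumer that needs the full token list reconstructs it from the request object. A minimal sketch of that pattern, where requests is a hypothetical req_id -> Request lookup table kept by the consumer:

    cached_reqs = scheduler_output.scheduled_cached_reqs
    for i, req_id in enumerate(cached_reqs.req_ids):
        request = requests[req_id]  # hypothetical lookup, not part of this diff
        # Tokens known so far = previously computed tokens plus the new
        # tokens delivered in this step's diff for this request.
        total_tokens = (cached_reqs.num_computed_tokens[i] +
                        len(cached_reqs.new_token_ids[i]))
        token_ids = request.all_token_ids[:total_tokens]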

View File

@@ -83,29 +83,27 @@ class NewRequestData:

 @dataclass
 class CachedRequestData:

-    req_id: str
+    req_ids: list[str]
     # If resumed_from_preemption is False, new_block_ids will be appended to
     # the request's block IDs. If True, new_block_ids will be used as the
     # request's block IDs instead of appending to the existing block IDs.
-    resumed_from_preemption: bool
-    new_token_ids: list[int]
-    new_block_ids: tuple[list[int], ...]
-    num_computed_tokens: int
+    resumed_from_preemption: list[bool]
+    new_token_ids: list[list[int]]
+    new_block_ids: list[tuple[list[int], ...]]
+    num_computed_tokens: list[int]
+
+    @property
+    def num_reqs(self) -> int:
+        return len(self.req_ids)

     @classmethod
-    def from_request(
-        cls,
-        request: Request,
-        resumed_from_preemption: bool,
-        new_token_ids: list[int],
-        new_block_ids: tuple[list[int], ...],
-    ) -> CachedRequestData:
+    def make_empty(cls) -> CachedRequestData:
         return cls(
-            req_id=request.request_id,
-            resumed_from_preemption=resumed_from_preemption,
-            new_token_ids=new_token_ids,
-            new_block_ids=new_block_ids,
-            num_computed_tokens=request.num_computed_tokens,
+            req_ids=[],
+            resumed_from_preemption=[],
+            new_token_ids=[],
+            new_block_ids=[],
+            num_computed_tokens=[],
         )
@@ -119,7 +117,7 @@ class SchedulerOutput:
     # list of the requests that have been scheduled before.
     # Since the request's data is already cached in the worker processes,
     # we only send the diff to minimize the communication cost.
-    scheduled_cached_reqs: list[CachedRequestData]
+    scheduled_cached_reqs: CachedRequestData

     # req_id -> num_scheduled_tokens
     # Number of tokens scheduled for each request.
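
For reference, a small usage sketch of the new structure; the request IDs, token IDs, and block IDs below are hypothetical values, and the import path matches the test changes above:

    from vllm.v1.core.sched.output import CachedRequestData

    # Empty batch, as the updated tests pass for steps with no cached requests.
    empty = CachedRequestData.make_empty()
    assert empty.num_reqs == 0

    # A batch describing two previously scheduled requests. All five fields
    # are parallel lists; position i always refers to the same request.
    batch = CachedRequestData(
        req_ids=["req-0", "req-1"],
        resumed_from_preemption=[False, True],
        new_token_ids=[[17], []],
        new_block_ids=[([42], ), ([3, 4, 5], )],
        num_computed_tokens=[128, 0],
    )
    assert batch.num_reqs == 2
    for i, req_id in enumerate(batch.req_ids):
        # resumed_from_preemption[i] tells the consumer whether new_block_ids[i]
        # replaces the request's block IDs or extends them.
        print(req_id, batch.num_computed_tokens[i], batch.new_block_ids[i])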

View File

@@ -3,8 +3,9 @@

 from __future__ import annotations

+import itertools
 import time
-from collections import defaultdict, deque
+from collections import defaultdict
 from collections.abc import Iterable
 from typing import Any, Optional, Union
@@ -117,12 +118,6 @@ class Scheduler(SchedulerInterface):
         # KV Connector: requests in process of async KV loading or recving
         self.finished_recving_kv_req_ids: set[str] = set()

-        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
-        # them at each scheduling step.
-        # Request id -> deque of CachedRequestData
-        self._cached_reqs_data: dict[
-            str, deque[CachedRequestData]] = defaultdict(deque)
-
         # Encoder-related.
         # Calculate encoder cache size if applicable
         # NOTE: For now we use the same budget for both compute and space.
@@ -547,27 +542,16 @@ class Scheduler(SchedulerInterface):
                 req_to_new_block_ids[req.request_id])
             for req in scheduled_new_reqs
         ]
-        resumed_reqs_data = [
-            self._make_cached_request_data(
-                req,
-                num_scheduled_tokens[req.request_id],
-                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
-                req_to_new_block_ids[req.request_id],
-                resumed_from_preemption=True,
-            ) for req in scheduled_resumed_reqs
-        ]
-        running_reqs_data = [
-            self._make_cached_request_data(
-                req,
-                num_scheduled_tokens[req.request_id],
-                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
-                req_to_new_block_ids[req.request_id],
-                resumed_from_preemption=False,
-            ) for req in scheduled_running_reqs
-        ]
+        cached_reqs_data = self._make_cached_request_data(
+            scheduled_running_reqs,
+            scheduled_resumed_reqs,
+            num_scheduled_tokens,
+            scheduled_spec_decode_tokens,
+            req_to_new_block_ids,
+        )
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,
-            scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
+            scheduled_cached_reqs=cached_reqs_data,
             num_scheduled_tokens=num_scheduled_tokens,
             total_num_scheduled_tokens=total_num_scheduled_tokens,
             scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
@@ -613,34 +597,39 @@ class Scheduler(SchedulerInterface):
     def _make_cached_request_data(
         self,
-        request: Request,
-        num_scheduled_tokens: int,
-        num_scheduled_spec_tokens: int,
-        new_block_ids: tuple[list[int], ...],
-        resumed_from_preemption: bool,
+        running_reqs: list[Request],
+        resumed_reqs: list[Request],
+        num_scheduled_tokens: dict[str, int],
+        spec_decode_tokens: dict[str, list[int]],
+        req_to_new_block_ids: dict[str, tuple[list[int], ...]],
     ) -> CachedRequestData:
-        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
-        # them at each scheduling step.
-        num_computed_tokens = request.num_computed_tokens
-        num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
-        new_token_ids = request.all_token_ids[
-            num_computed_tokens:num_computed_tokens + num_regular_tokens]
-        req_data_queue = self._cached_reqs_data.get(request.request_id)
-        if req_data_queue:
-            req_data = req_data_queue.popleft()
-            req_data.resumed_from_preemption = resumed_from_preemption
-            req_data.new_token_ids = new_token_ids
-            req_data.new_block_ids = new_block_ids
-            req_data.num_computed_tokens = num_computed_tokens
-        else:
-            # No cached request data, or all cached request data has been
-            # used by the scheduled requests.
-            req_data = CachedRequestData.from_request(request,
-                                                      resumed_from_preemption,
-                                                      new_token_ids,
-                                                      new_block_ids)
-        return req_data
+        req_ids: list[str] = []
+        new_token_ids: list[list[int]] = []
+        new_block_ids: list[tuple[list[int], ...]] = []
+        num_computed_tokens: list[int] = []
+
+        for req in itertools.chain(running_reqs, resumed_reqs):
+            req_id = req.request_id
+            req_ids.append(req_id)
+            num_tokens = (num_scheduled_tokens[req_id] -
+                          len(spec_decode_tokens.get(req_id, ())))
+            token_ids = req.all_token_ids[req.num_computed_tokens:req.
+                                          num_computed_tokens + num_tokens]
+            new_token_ids.append(token_ids)
+            new_block_ids.append(req_to_new_block_ids[req_id])
+            num_computed_tokens.append(req.num_computed_tokens)
+        # Because resumed_reqs is usually empty, it is more efficient to do
+        # in-place appending so that we don't need to allocate a new list.
+        resumed_from_preemption = [False] * len(running_reqs)
+        resumed_from_preemption += [True] * len(resumed_reqs)
+
+        return CachedRequestData(
+            req_ids=req_ids,
+            resumed_from_preemption=resumed_from_preemption,
+            new_token_ids=new_token_ids,
+            new_block_ids=new_block_ids,
+            num_computed_tokens=num_computed_tokens,
+        )

     def _try_schedule_encoder_inputs(
         self,
@@ -870,19 +859,11 @@ class Scheduler(SchedulerInterface):
             if not stopped:
                 new_running.append(request)

+        self.running = new_running
+
         # KV Connector: update state for finished KV Transfers.
         self._update_from_kv_xfer_finished(model_runner_output)

-        # Return the cached request data to the queue so they can be reused.
-        for req_data in scheduler_output.scheduled_cached_reqs:
-            # NOTE(rob): since we free stopped reqs above, adding stopped reqs
-            # to _cached_reqs_data will cause a memory leak.
-            if req_data.req_id not in self.finished_req_ids:
-                self._cached_reqs_data[req_data.req_id].append(req_data)
-
-        self.running = new_running
-
         # Create EngineCoreOutputs for all clients that have requests with
         # outputs in this step.
         engine_core_outputs = {
@@ -965,13 +946,11 @@ class Scheduler(SchedulerInterface):
             self._free_request(request)

     def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
         assert request.is_finished()

         delay_free_blocks, kv_xfer_params = self._connector_finished(request)
         self.encoder_cache_manager.free(request)
         request_id = request.request_id
-        self._cached_reqs_data.pop(request_id, None)
         self.finished_req_ids.add(request_id)
         if self.finished_req_ids_dict is not None:
             self.finished_req_ids_dict[request.client_index].add(request_id)
@@ -983,7 +962,6 @@ class Scheduler(SchedulerInterface):
     def _free_blocks(self, request: Request):
         assert request.is_finished()
-        assert request.request_id not in self._cached_reqs_data
         self.kv_cache_manager.free(request)
         self.kv_cache_manager.free_block_hashes(request)
         del self.requests[request.request_id]
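
To make the new construction concrete, consider a hypothetical step with two running requests and one request resumed after preemption (all IDs, token values, and block IDs are invented for illustration). Running requests come first and resumed requests follow, matching the itertools.chain order above, and each new_token_ids entry excludes the speculative tokens scheduled for that step:

    # running: "A" (100 computed tokens, 2 scheduled, 1 of them speculative)
    #          "B" (250 computed tokens, 1 scheduled, no spec tokens)
    # resumed: "C" (0 computed tokens after preemption, 1 scheduled)
    # _make_cached_request_data would then return roughly:
    CachedRequestData(
        req_ids=["A", "B", "C"],
        resumed_from_preemption=[False, False, True],
        # "A": 2 scheduled - 1 spec token = 1 regular token, sliced from
        # all_token_ids[100:101]; "B" and "C" analogously.
        new_token_ids=[[1001], [2002], [3003]],
        new_block_ids=[([], ), ([57], ), ([4, 5, 6], )],
        num_computed_tokens=[100, 250, 0],
    )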

View File

@@ -470,34 +470,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             req_ids_to_add.append(req_id)

         # Update the states of the running/resumed requests.
-        for req_data in scheduler_output.scheduled_cached_reqs:
-            req_id = req_data.req_id
+        req_data = scheduler_output.scheduled_cached_reqs
+        for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
+            num_computed_tokens = req_data.num_computed_tokens[i]
+            new_token_ids = req_data.new_token_ids[i]
+            new_block_ids = req_data.new_block_ids[i]
+            resumed_from_preemption = req_data.resumed_from_preemption[i]

             # Update the cached states.
-            num_computed_tokens = req_data.num_computed_tokens
             req_state.num_computed_tokens = num_computed_tokens
             # Add the sampled token(s) from the previous step (if any).
             # This doesn't include "unverified" tokens like spec decode tokens.
-            num_new_tokens = (num_computed_tokens +
-                              len(req_data.new_token_ids) -
+            num_new_tokens = (num_computed_tokens + len(new_token_ids) -
                               req_state.num_tokens)
             if num_new_tokens == 1:
                 # Avoid slicing list in most common case.
-                req_state.output_token_ids.append(req_data.new_token_ids[-1])
+                req_state.output_token_ids.append(new_token_ids[-1])
             elif num_new_tokens > 0:
                 req_state.output_token_ids.extend(
-                    req_data.new_token_ids[-num_new_tokens:])
+                    new_token_ids[-num_new_tokens:])
             # Update the block IDs.
-            if not req_data.resumed_from_preemption:
+            if not resumed_from_preemption:
                 # Append the new blocks to the existing block IDs.
-                for block_ids, new_block_ids in zip(req_state.block_ids,
-                                                    req_data.new_block_ids):
-                    block_ids.extend(new_block_ids)
+                for block_ids, new_ids in zip(req_state.block_ids,
+                                              new_block_ids):
+                    block_ids.extend(new_ids)
             else:
                 # The request is resumed from preemption.
                 # Replace the existing block IDs with the new ones.
-                req_state.block_ids = req_data.new_block_ids
+                req_state.block_ids = new_block_ids

             req_index = self.input_batch.req_id_to_index.get(req_id)
             if req_index is None:
@@ -510,14 +512,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
-            self.input_batch.block_table.append_row(req_data.new_block_ids,
-                                                    req_index)
+            self.input_batch.block_table.append_row(new_block_ids, req_index)

             # Add new_token_ids to token_ids_cpu.
             start_token_index = num_computed_tokens
-            end_token_index = num_computed_tokens + len(req_data.new_token_ids)
+            end_token_index = num_computed_tokens + len(new_token_ids)
             self.input_batch.token_ids_cpu[
-                req_index,
-                start_token_index:end_token_index] = req_data.new_token_ids
+                req_index, start_token_index:end_token_index] = new_token_ids
             self.input_batch.num_tokens_no_spec[req_index] = end_token_index
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(

View File

@@ -418,21 +418,24 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             req_ids_to_add.append(req_id)

         # Update the states of the running/resumed requests.
-        for req_data in scheduler_output.scheduled_cached_reqs:
-            req_id = req_data.req_id
+        req_data = scheduler_output.scheduled_cached_reqs
+        for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
+            num_computed_tokens = req_data.num_computed_tokens[i]
+            new_block_ids = req_data.new_block_ids[i]
+            resumed_from_preemption = req_data.resumed_from_preemption[i]

             # Update the cached states.
-            req_state.num_computed_tokens = req_data.num_computed_tokens
-            if not req_data.resumed_from_preemption:
+            req_state.num_computed_tokens = num_computed_tokens
+            if not resumed_from_preemption:
                 # Append the new blocks to the existing block IDs.
-                for block_ids, new_block_ids in zip(req_state.block_ids,
-                                                    req_data.new_block_ids):
-                    block_ids.extend(new_block_ids)
+                for block_ids, new_ids in zip(req_state.block_ids,
+                                              new_block_ids):
+                    block_ids.extend(new_ids)
             else:
                 # The request is resumed from preemption.
                 # Replace the existing block IDs with the new ones.
-                req_state.block_ids = req_data.new_block_ids
+                req_state.block_ids = new_block_ids

             req_index = self.input_batch.req_id_to_index.get(req_id)
             if req_index is None:
@@ -444,9 +447,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
-                req_data.num_computed_tokens)
-            self.input_batch.block_table.append_row(req_data.new_block_ids,
-                                                    req_index)
+                num_computed_tokens)
+            self.input_batch.block_table.append_row(new_block_ids, req_index)

             # Add the new or resumed requests to the persistent batch.
             # The smaller empty indices are filled first.