[Optimization] Use Shared CachedRequestData Instance Across All Requests (#20232)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 2965c99c86
commit 2863befce3
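In outline (as inferred from the diff below), the scheduler now builds one CachedRequestData per scheduling step instead of one object per request: the per-request fields become parallel lists indexed by position, SchedulerOutput.scheduled_cached_reqs holds that single instance, and the per-request reuse queue (_cached_reqs_data) is removed. A minimal sketch of the access-pattern change, where handle_cached_request() is a hypothetical stand-in for whatever a consumer does with the diffed state:

# Before: one CachedRequestData object per cached request.
for cached_req in scheduler_output.scheduled_cached_reqs:
    handle_cached_request(cached_req.req_id,
                          cached_req.num_computed_tokens,
                          cached_req.new_token_ids,
                          cached_req.new_block_ids,
                          cached_req.resumed_from_preemption)

# After: one shared CachedRequestData holding parallel lists.
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
    handle_cached_request(req_id,
                          cached_reqs.num_computed_tokens[i],
                          cached_reqs.new_token_ids[i],
                          cached_reqs.new_block_ids[i],
                          cached_reqs.resumed_from_preemption[i])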
@@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
-from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec)
@@ -198,7 +198,7 @@ def test_schedule(enable_prefix_caching: Optional[bool],
     # Test initial scheduling
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == len(requests)
-    assert len(output.scheduled_cached_reqs) == 0
+    assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0
     # Verify all requests are scheduled.
     for req_id, num_tokens in output.num_scheduled_tokens.items():
@@ -225,7 +225,7 @@ def test_schedule_multimodal_requests():
 
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == len(requests)
-    assert len(output.scheduled_cached_reqs) == 0
+    assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0
     for req_id, num_tokens in output.num_scheduled_tokens.items():
         assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
@@ -259,7 +259,7 @@ def test_schedule_partial_requests():
 
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == 3
-    assert len(output.scheduled_cached_reqs) == 0
+    assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0
 
     assert scheduler.max_num_encoder_input_tokens == 1024
@@ -295,7 +295,7 @@ def test_schedule_partial_requests():
     output = scheduler.schedule()
     assert len(scheduler.running) == 3
     assert len(output.scheduled_new_reqs) == 0
-    assert len(output.scheduled_cached_reqs) == 2
+    assert output.scheduled_cached_reqs.num_reqs == 2
     assert len(output.finished_req_ids) == 0
     assert output.num_scheduled_tokens[requests[0].request_id] == 1
     assert output.num_scheduled_tokens[requests[1].request_id] == 700
@@ -319,7 +319,7 @@ def test_no_mm_input_chunking():
 
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == 1
-    assert len(output.scheduled_cached_reqs) == 0
+    assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0
     # We want to only see the 400 text tokens at the start scheduled
     assert output.num_scheduled_tokens[requests[0].request_id] == 400
@@ -342,7 +342,7 @@ def test_no_mm_input_chunking():
     output = scheduler.schedule()
     assert len(scheduler.running) == 1
     assert len(output.scheduled_new_reqs) == 0
-    assert len(output.scheduled_cached_reqs) == 1
+    assert output.scheduled_cached_reqs.num_reqs == 1
     assert len(output.finished_req_ids) == 0
     assert output.num_scheduled_tokens[requests[0].request_id] == 800
 
@@ -379,7 +379,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
 
     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == 3
-    assert len(output.scheduled_cached_reqs) == 0
+    assert output.scheduled_cached_reqs.num_reqs == 0
     assert len(output.finished_req_ids) == 0
 
     # The first request is scheduled partially - 400.
@@ -408,7 +408,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     output1 = scheduler.schedule()
     assert len(scheduler.running) == 3
     assert len(output1.scheduled_new_reqs) == 0
-    assert len(output1.scheduled_cached_reqs) == 3
+    assert output1.scheduled_cached_reqs.num_reqs == 3
     assert len(output1.finished_req_ids) == 0
     assert output1.num_scheduled_tokens[requests[0].request_id] == 400
     assert output1.num_scheduled_tokens[requests[1].request_id] == 400
@@ -430,7 +430,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
     output2 = scheduler.schedule()
     assert len(scheduler.running) == 3
     assert len(output2.scheduled_new_reqs) == 0
-    assert len(output2.scheduled_cached_reqs) == 3
+    assert output2.scheduled_cached_reqs.num_reqs == 3
     assert len(output2.finished_req_ids) == 0
     assert output2.num_scheduled_tokens[requests[0].request_id] == 1
     assert output2.num_scheduled_tokens[requests[1].request_id] == 1
@@ -449,23 +449,24 @@ def test_stop_via_update_from_output():
         scheduler.requests[req.request_id] = req
         scheduler.running.append(req)
 
-    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
-                                       scheduled_cached_reqs=[],
-                                       num_scheduled_tokens={
-                                           requests[0].request_id: 1,
-                                           requests[1].request_id: 2
-                                       },
-                                       total_num_scheduled_tokens=3,
-                                       scheduled_encoder_inputs={},
-                                       scheduled_spec_decode_tokens={
-                                           requests[0].request_id: [],
-                                           requests[1].request_id: [10]
-                                       },
-                                       num_common_prefix_blocks=0,
-                                       finished_req_ids=set(),
-                                       free_encoder_input_ids=[],
-                                       structured_output_request_ids={},
-                                       grammar_bitmask=None)
+    scheduler_output = SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
+        num_scheduled_tokens={
+            requests[0].request_id: 1,
+            requests[1].request_id: 2
+        },
+        total_num_scheduled_tokens=3,
+        scheduled_encoder_inputs={},
+        scheduled_spec_decode_tokens={
+            requests[0].request_id: [],
+            requests[1].request_id: [10]
+        },
+        num_common_prefix_blocks=0,
+        finished_req_ids=set(),
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None)
 
     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
@@ -501,23 +502,25 @@ def test_stop_via_update_from_output():
         scheduler.requests[req.request_id] = req
         scheduler.running.append(req)
 
-    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
-                                       scheduled_cached_reqs=[],
-                                       num_scheduled_tokens={
-                                           requests[0].request_id: 3,
-                                           requests[1].request_id: 2
-                                       },
-                                       total_num_scheduled_tokens=5,
-                                       scheduled_encoder_inputs={},
-                                       scheduled_spec_decode_tokens={
-                                           requests[0].request_id: [10, 42],
-                                           requests[1].request_id: [13]
-                                       },
-                                       num_common_prefix_blocks=0,
-                                       finished_req_ids=set(),
-                                       free_encoder_input_ids=[],
-                                       structured_output_request_ids={},
-                                       grammar_bitmask=None)
+    scheduler_output = SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
+        num_scheduled_tokens={
+            requests[0].request_id: 3,
+            requests[1].request_id: 2
+        },
+        total_num_scheduled_tokens=5,
+        scheduled_encoder_inputs={},
+        scheduled_spec_decode_tokens={
+            requests[0].request_id: [10, 42],
+            requests[1].request_id: [13]
+        },
+        num_common_prefix_blocks=0,
+        finished_req_ids=set(),
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
+    )
 
     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
@@ -551,23 +554,25 @@ def test_stop_via_update_from_output():
         scheduler.requests[req.request_id] = req
         scheduler.running.append(req)
 
-    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
-                                       scheduled_cached_reqs=[],
-                                       num_scheduled_tokens={
-                                           requests[0].request_id: 3,
-                                           requests[1].request_id: 1
-                                       },
-                                       total_num_scheduled_tokens=4,
-                                       scheduled_encoder_inputs={},
-                                       scheduled_spec_decode_tokens={
-                                           requests[0].request_id: [10, 11],
-                                           requests[1].request_id: []
-                                       },
-                                       num_common_prefix_blocks=0,
-                                       finished_req_ids=set(),
-                                       free_encoder_input_ids=[],
-                                       structured_output_request_ids={},
-                                       grammar_bitmask=None)
+    scheduler_output = SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
+        num_scheduled_tokens={
+            requests[0].request_id: 3,
+            requests[1].request_id: 1
+        },
+        total_num_scheduled_tokens=4,
+        scheduled_encoder_inputs={},
+        scheduled_spec_decode_tokens={
+            requests[0].request_id: [10, 11],
+            requests[1].request_id: []
+        },
+        num_common_prefix_blocks=0,
+        finished_req_ids=set(),
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
+    )
 
     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
@@ -603,7 +608,7 @@ def test_stop_via_update_from_output():
 
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={requests[0].request_id: 3},
         total_num_scheduled_tokens=3,
         scheduled_encoder_inputs={},
@@ -1208,7 +1213,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
     assert len(scheduler.waiting) == 0
     assert len(scheduler.running) == 0
     assert len(scheduler.finished_req_ids) == 0
-    assert len(scheduler._cached_reqs_data) == 0
 
     # EncoderCacheManager.
     assert len(scheduler.encoder_cache_manager.freed) == 0
@@ -66,7 +66,7 @@ def test_basic_lifecycle():
     assert len(scheduler_output.finished_req_ids) == 1
     assert request_id in scheduler_output.finished_req_ids
     assert len(scheduler_output.scheduled_new_reqs) == 0
-    assert len(scheduler_output.scheduled_cached_reqs) == 0
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
     assert len(scheduler.finished_req_ids) == 0
 
     # (2b): execute_model()
@@ -81,7 +81,7 @@ def test_basic_lifecycle():
     assert len(scheduler.running) == 0
     assert len(scheduler_output.finished_req_ids) == 0
     assert len(scheduler_output.scheduled_new_reqs) == 0
-    assert len(scheduler_output.scheduled_cached_reqs) == 0
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
     assert len(scheduler.finished_req_ids) == 0
 
     # (3b): execute_model()
@@ -36,7 +36,7 @@ def test_basic_lifecycle():
     # Nothing running and empty scheduler output.
     assert len(scheduler.running) == 0
     assert len(scheduler_output.scheduled_new_reqs) == 0
-    assert len(scheduler_output.scheduled_cached_reqs) == 0
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
     assert len(scheduler_output.num_scheduled_tokens) == 0
     assert scheduler_output.total_num_scheduled_tokens == 0
 
@@ -158,7 +158,7 @@ def test_interleaved_lifecycle():
     assert len(scheduler.running) == 2
     assert len(scheduler.waiting) == 1
     assert len(scheduler_output.scheduled_new_reqs) == 1
-    assert len(scheduler_output.scheduled_cached_reqs) == 1
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 1
 
     model_runner_output = create_model_runner_output(
         [request_local_a, request_local_b])
@@ -169,7 +169,7 @@ def test_interleaved_lifecycle():
     assert len(scheduler.running) == 2
     assert len(scheduler.waiting) == 1
     assert len(scheduler_output.scheduled_new_reqs) == 0
-    assert len(scheduler_output.scheduled_cached_reqs) == 2
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
 
     model_runner_output = create_model_runner_output(
         reqs=[request_local_a, request_local_b])
@@ -177,14 +177,14 @@ def test_interleaved_lifecycle():
     assert len(scheduler.running) == 2
     assert len(scheduler.waiting) == 1
     assert len(scheduler_output.scheduled_new_reqs) == 0
-    assert len(scheduler_output.scheduled_cached_reqs) == 2
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
 
     # STEP 4: KVs arrive.
     scheduler_output = scheduler.schedule()
     assert len(scheduler.running) == 2
     assert len(scheduler.waiting) == 1
     assert len(scheduler_output.scheduled_new_reqs) == 0
-    assert len(scheduler_output.scheduled_cached_reqs) == 2
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
 
     model_runner_output = create_model_runner_output(
         [request_local_a, request_local_b],
@@ -196,7 +196,7 @@ def test_interleaved_lifecycle():
     assert len(scheduler.running) == 3
     assert len(scheduler.waiting) == 0
     assert len(scheduler_output.scheduled_new_reqs) == 1
-    assert len(scheduler_output.scheduled_cached_reqs) == 2
+    assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
 
     model_runner_output = create_model_runner_output(
         [request_local_a, request_local_b, request_remote])
@@ -25,7 +25,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
     assert len(scheduler.running) == 0
     assert len(scheduler.finished_req_ids) == 0
     assert len(scheduler.finished_recving_kv_req_ids) == 0
-    assert len(scheduler._cached_reqs_data) == 0
 
     # EncoderCacheManager.
     assert len(scheduler.encoder_cache_manager.freed) == 0
@@ -82,7 +82,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
 
     return SchedulerOutput(
         scheduled_new_reqs=new_reqs,
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens=num_scheduled_tokens,
         total_num_scheduled_tokens=total_num_scheduled_tokens,
         scheduled_spec_decode_tokens={},
@@ -161,7 +161,7 @@ def test_update_states_request_finished(model_runner):
     # finish req
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={},
         total_num_scheduled_tokens=0,
         scheduled_spec_decode_tokens={},
@@ -191,7 +191,7 @@ def test_update_states_request_resumed(model_runner):
     # unschedule req
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={},
         total_num_scheduled_tokens=0,
         scheduled_spec_decode_tokens={},
@@ -209,16 +209,16 @@ def test_update_states_request_resumed(model_runner):
 
     # resume req
     cached_req_data = CachedRequestData(
-        req_id=req_id,
-        resumed_from_preemption=False,
-        new_token_ids=[],
-        new_block_ids=([], ),
-        num_computed_tokens=0,
+        req_ids=[req_id],
+        resumed_from_preemption=[False],
+        new_token_ids=[[]],
+        new_block_ids=[([], )],
+        num_computed_tokens=[0],
     )
 
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[cached_req_data],
+        scheduled_cached_reqs=cached_req_data,
         num_scheduled_tokens={req_id: 1},
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
@@ -249,7 +249,7 @@ def test_update_states_no_changes(model_runner):
     # schedule req
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={req_id: 1},
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
@@ -284,7 +284,7 @@ def test_update_states_request_unscheduled(model_runner):
     # unschedule req_1
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={req_ids[0]: 1},
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
@@ -133,7 +133,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
 
     return SchedulerOutput(
         scheduled_new_reqs=new_reqs,
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens=num_scheduled_tokens,
         total_num_scheduled_tokens=total_num_scheduled_tokens,
         scheduled_spec_decode_tokens={},
@@ -199,7 +199,7 @@ def test_update_states_request_finished(model_runner):
     # finish req
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={},
         total_num_scheduled_tokens=0,
         scheduled_spec_decode_tokens={},
@@ -231,7 +231,7 @@ def test_update_states_request_resumed(model_runner):
     # unschedule req
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={},
         total_num_scheduled_tokens=0,
         scheduled_spec_decode_tokens={},
@@ -249,16 +249,16 @@ def test_update_states_request_resumed(model_runner):
 
     # resume req
     cached_req_data = CachedRequestData(
-        req_id=req_id,
-        resumed_from_preemption=False,
-        new_token_ids=[],
-        new_block_ids=([], ),
-        num_computed_tokens=0,
+        req_ids=[req_id],
+        resumed_from_preemption=[False],
+        new_token_ids=[[]],
+        new_block_ids=([[0]], ),
+        num_computed_tokens=[0],
     )
 
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[cached_req_data],
+        scheduled_cached_reqs=cached_req_data,
         num_scheduled_tokens={req_id: 1},
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
@@ -339,7 +339,7 @@ def test_update_states_no_changes(model_runner):
     # schedule req
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={req_id: 1},
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
@@ -376,7 +376,7 @@ def test_update_states_request_unscheduled(model_runner):
     # unschedule req_1
     scheduler_output = SchedulerOutput(
         scheduled_new_reqs=[],
-        scheduled_cached_reqs=[],
+        scheduled_cached_reqs=CachedRequestData.make_empty(),
         num_scheduled_tokens={req_ids[0]: 1},
         total_num_scheduled_tokens=1,
         scheduled_spec_decode_tokens={},
@@ -371,45 +371,48 @@ class P2pNcclConnector(KVConnectorBase_V1):
                                  block_size=self._block_size)
                 self._requests_need_load.pop(new_req.req_id)
 
-        for cached_req in scheduler_output.scheduled_cached_reqs:
+        cached_reqs = scheduler_output.scheduled_cached_reqs
+        for i, req_id in enumerate(cached_reqs.req_ids):
+            num_computed_tokens = cached_reqs.num_computed_tokens[i]
+            new_block_ids = cached_reqs.new_block_ids[i]
+            resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
+
             if self.is_producer:
                 num_scheduled_tokens = (
-                    scheduler_output.num_scheduled_tokens)[cached_req.req_id]
-                num_tokens = (num_scheduled_tokens +
-                              cached_req.num_computed_tokens)
-                assert cached_req.req_id in self.chunked_prefill
-                block_ids = cached_req.new_block_ids[0]
-                if not cached_req.resumed_from_preemption:
-                    block_ids = (self.chunked_prefill[cached_req.req_id][0] +
-                                 block_ids)
-                prompt_token_ids = self.chunked_prefill[cached_req.req_id][1]
+                    scheduler_output.num_scheduled_tokens)[req_id]
+                num_tokens = (num_scheduled_tokens + num_computed_tokens)
+                assert req_id in self.chunked_prefill
+                block_ids = new_block_ids[0]
+                if not resumed_from_preemption:
+                    block_ids = (self.chunked_prefill[req_id][0] + block_ids)
+                prompt_token_ids = self.chunked_prefill[req_id][1]
                 # the request's prompt is chunked prefill again
                 if num_tokens < len(prompt_token_ids):
-                    self.chunked_prefill[cached_req.req_id] = (
-                        block_ids, prompt_token_ids)
+                    self.chunked_prefill[req_id] = (block_ids,
+                                                    prompt_token_ids)
                     continue
                 # the request's prompt is all prefilled finally
-                meta.add_request(request_id=cached_req.req_id,
+                meta.add_request(request_id=req_id,
                                  token_ids=prompt_token_ids,
                                  block_ids=block_ids,
                                  block_size=self._block_size)
-                self.chunked_prefill.pop(cached_req.req_id, None)
+                self.chunked_prefill.pop(req_id, None)
                 continue
 
             # NOTE(rob): here we rely on the resumed requests being
             # the first N requests in the list scheduled_cache_reqs.
-            if not cached_req.resumed_from_preemption:
+            if not resumed_from_preemption:
                 break
-            if cached_req.req_id in self._requests_need_load:
-                request, _ = self._requests_need_load.pop(cached_req.req_id)
-                total_tokens = cached_req.num_computed_tokens + 1
+            if req_id in self._requests_need_load:
+                request, _ = self._requests_need_load.pop(req_id)
+                total_tokens = num_computed_tokens + 1
                 token_ids = request.all_token_ids[:total_tokens]
 
                 # NOTE(rob): For resumed req, new_block_ids is all
                 # of the block_ids for the request.
-                block_ids = cached_req.new_block_ids[0]
+                block_ids = new_block_ids[0]
 
-                meta.add_request(request_id=cached_req.req_id,
+                meta.add_request(request_id=req_id,
                                  token_ids=token_ids,
                                  block_ids=block_ids,
                                  block_size=self._block_size)
@@ -304,23 +304,28 @@ class SharedStorageConnector(KVConnectorBase_V1):
                                  block_size=self._block_size,
                                  is_store=True)
 
-        for cached_req in scheduler_output.scheduled_cached_reqs:
+        cached_reqs = scheduler_output.scheduled_cached_reqs
+        for i, req_id in enumerate(cached_reqs.req_ids):
+            num_computed_tokens = cached_reqs.num_computed_tokens[i]
+            new_token_ids = cached_reqs.new_token_ids[i]
+            new_block_ids = cached_reqs.new_block_ids[i]
+            resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
+
             # NOTE(rob): here we rely on the resumed requests being
             # the first N requests in the list scheduled_cache_reqs.
-            if not cached_req.resumed_from_preemption:
+            if not resumed_from_preemption:
                 break
-            if cached_req.req_id in self._requests_need_load:
+            if req_id in self._requests_need_load:
                 # NOTE(rob): cached_req_data does not have the full
                 # list of token ids (only new tokens). So we look it
                 # up in the actual request object.
-                request = self._requests_need_load[cached_req.req_id]
-                total_tokens = (len(cached_req.new_token_ids) +
-                                cached_req.num_computed_tokens)
+                request = self._requests_need_load[req_id]
+                total_tokens = (len(new_token_ids) + num_computed_tokens)
                 token_ids = request.all_token_ids[:total_tokens]
 
                 # NOTE(rob): For resumed req, new_block_ids is all
                 # of the block_ids for the request.
-                block_ids = cached_req.new_block_ids[0]
+                block_ids = new_block_ids[0]
 
                 meta.add_request(token_ids=token_ids,
                                  block_ids=block_ids,
@@ -83,29 +83,27 @@ class NewRequestData:
 @dataclass
 class CachedRequestData:
 
-    req_id: str
+    req_ids: list[str]
     # If resumed_from_preemption is False, new_block_ids will be appended to
     # the request's block IDs. If True, new_block_ids will be used as the
     # request's block IDs instead of appending to the existing block IDs.
-    resumed_from_preemption: bool
-    new_token_ids: list[int]
-    new_block_ids: tuple[list[int], ...]
-    num_computed_tokens: int
+    resumed_from_preemption: list[bool]
+    new_token_ids: list[list[int]]
+    new_block_ids: list[tuple[list[int], ...]]
+    num_computed_tokens: list[int]
 
+    @property
+    def num_reqs(self) -> int:
+        return len(self.req_ids)
+
     @classmethod
-    def from_request(
-        cls,
-        request: Request,
-        resumed_from_preemption: bool,
-        new_token_ids: list[int],
-        new_block_ids: tuple[list[int], ...],
-    ) -> CachedRequestData:
+    def make_empty(cls) -> CachedRequestData:
         return cls(
-            req_id=request.request_id,
-            resumed_from_preemption=resumed_from_preemption,
-            new_token_ids=new_token_ids,
-            new_block_ids=new_block_ids,
-            num_computed_tokens=request.num_computed_tokens,
+            req_ids=[],
+            resumed_from_preemption=[],
+            new_token_ids=[],
+            new_block_ids=[],
+            num_computed_tokens=[],
         )
 
 
@@ -119,7 +117,7 @@ class SchedulerOutput:
     # list of the requests that have been scheduled before.
     # Since the request's data is already cached in the worker processes,
     # we only send the diff to minimize the communication cost.
-    scheduled_cached_reqs: list[CachedRequestData]
+    scheduled_cached_reqs: CachedRequestData
 
    # req_id -> num_scheduled_tokens
    # Number of tokens scheduled for each request.
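As a rough illustration of the new CachedRequestData interface defined in the hunk above (the literal values below are made up for the example and are not part of the commit): code that previously took len(scheduled_cached_reqs) now reads the num_reqs property, and an empty placeholder is built with make_empty().

empty = CachedRequestData.make_empty()
assert empty.num_reqs == 0

# One entry per request, kept in parallel lists.
batch = CachedRequestData(
    req_ids=["req-0"],                # hypothetical request id
    resumed_from_preemption=[False],  # False: new_block_ids are appended
    new_token_ids=[[101, 102]],       # newly sampled tokens for req-0
    new_block_ids=[([7], )],          # one KV-cache group with one new block
    num_computed_tokens=[16],
)
assert batch.num_reqs == 1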
@@ -3,8 +3,9 @@
 
 from __future__ import annotations
 
+import itertools
 import time
-from collections import defaultdict, deque
+from collections import defaultdict
 from collections.abc import Iterable
 from typing import Any, Optional, Union
 
@@ -117,12 +118,6 @@ class Scheduler(SchedulerInterface):
         # KV Connector: requests in process of async KV loading or recving
         self.finished_recving_kv_req_ids: set[str] = set()
 
-        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
-        # them at each scheduling step.
-        # Request id -> deque of CachedRequestData
-        self._cached_reqs_data: dict[
-            str, deque[CachedRequestData]] = defaultdict(deque)
-
         # Encoder-related.
         # Calculate encoder cache size if applicable
         # NOTE: For now we use the same budget for both compute and space.
@@ -547,27 +542,16 @@ class Scheduler(SchedulerInterface):
                                         req_to_new_block_ids[req.request_id])
             for req in scheduled_new_reqs
         ]
-        resumed_reqs_data = [
-            self._make_cached_request_data(
-                req,
-                num_scheduled_tokens[req.request_id],
-                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
-                req_to_new_block_ids[req.request_id],
-                resumed_from_preemption=True,
-            ) for req in scheduled_resumed_reqs
-        ]
-        running_reqs_data = [
-            self._make_cached_request_data(
-                req,
-                num_scheduled_tokens[req.request_id],
-                len(scheduled_spec_decode_tokens.get(req.request_id, ())),
-                req_to_new_block_ids[req.request_id],
-                resumed_from_preemption=False,
-            ) for req in scheduled_running_reqs
-        ]
+        cached_reqs_data = self._make_cached_request_data(
+            scheduled_running_reqs,
+            scheduled_resumed_reqs,
+            num_scheduled_tokens,
+            scheduled_spec_decode_tokens,
+            req_to_new_block_ids,
+        )
         scheduler_output = SchedulerOutput(
             scheduled_new_reqs=new_reqs_data,
-            scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
+            scheduled_cached_reqs=cached_reqs_data,
             num_scheduled_tokens=num_scheduled_tokens,
             total_num_scheduled_tokens=total_num_scheduled_tokens,
             scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
@@ -613,34 +597,39 @@ class Scheduler(SchedulerInterface):
 
     def _make_cached_request_data(
         self,
-        request: Request,
-        num_scheduled_tokens: int,
-        num_scheduled_spec_tokens: int,
-        new_block_ids: tuple[list[int], ...],
-        resumed_from_preemption: bool,
+        running_reqs: list[Request],
+        resumed_reqs: list[Request],
+        num_scheduled_tokens: dict[str, int],
+        spec_decode_tokens: dict[str, list[int]],
+        req_to_new_block_ids: dict[str, tuple[list[int], ...]],
     ) -> CachedRequestData:
-        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
-        # them at each scheduling step.
-        num_computed_tokens = request.num_computed_tokens
-        num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
-        new_token_ids = request.all_token_ids[
-            num_computed_tokens:num_computed_tokens + num_regular_tokens]
+        req_ids: list[str] = []
+        new_token_ids: list[list[int]] = []
+        new_block_ids: list[tuple[list[int], ...]] = []
+        num_computed_tokens: list[int] = []
 
-        req_data_queue = self._cached_reqs_data.get(request.request_id)
-        if req_data_queue:
-            req_data = req_data_queue.popleft()
-            req_data.resumed_from_preemption = resumed_from_preemption
-            req_data.new_token_ids = new_token_ids
-            req_data.new_block_ids = new_block_ids
-            req_data.num_computed_tokens = num_computed_tokens
-        else:
-            # No cached request data, or all cached request data has been
-            # used by the scheduled requests.
-            req_data = CachedRequestData.from_request(request,
-                                                      resumed_from_preemption,
-                                                      new_token_ids,
-                                                      new_block_ids)
-        return req_data
+        for req in itertools.chain(running_reqs, resumed_reqs):
+            req_id = req.request_id
+            req_ids.append(req_id)
+            num_tokens = (num_scheduled_tokens[req_id] -
+                          len(spec_decode_tokens.get(req_id, ())))
+            token_ids = req.all_token_ids[req.num_computed_tokens:req.
+                                          num_computed_tokens + num_tokens]
+            new_token_ids.append(token_ids)
+            new_block_ids.append(req_to_new_block_ids[req_id])
+            num_computed_tokens.append(req.num_computed_tokens)
+        # Because resumed_reqs is usually empty, it is more efficient to do
+        # in-place appending so that we don't need to allocate a new list.
+        resumed_from_preemption = [False] * len(running_reqs)
+        resumed_from_preemption += [True] * len(resumed_reqs)
+
+        return CachedRequestData(
+            req_ids=req_ids,
+            resumed_from_preemption=resumed_from_preemption,
+            new_token_ids=new_token_ids,
+            new_block_ids=new_block_ids,
+            num_computed_tokens=num_computed_tokens,
+        )
 
     def _try_schedule_encoder_inputs(
         self,
@@ -870,19 +859,11 @@ class Scheduler(SchedulerInterface):
 
             if not stopped:
                 new_running.append(request)
+        self.running = new_running
 
         # KV Connector: update state for finished KV Transfers.
         self._update_from_kv_xfer_finished(model_runner_output)
 
-        # Return the cached request data to the queue so they can be reused.
-        for req_data in scheduler_output.scheduled_cached_reqs:
-            # NOTE(rob): since we free stopped reqs above, adding stopped reqs
-            # to _cached_reqs_data will cause a memory leak.
-            if req_data.req_id not in self.finished_req_ids:
-                self._cached_reqs_data[req_data.req_id].append(req_data)
-
-        self.running = new_running
-
         # Create EngineCoreOutputs for all clients that have requests with
         # outputs in this step.
         engine_core_outputs = {
@@ -965,13 +946,11 @@ class Scheduler(SchedulerInterface):
             self._free_request(request)
 
     def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
-
         assert request.is_finished()
 
         delay_free_blocks, kv_xfer_params = self._connector_finished(request)
         self.encoder_cache_manager.free(request)
         request_id = request.request_id
-        self._cached_reqs_data.pop(request_id, None)
        self.finished_req_ids.add(request_id)
         if self.finished_req_ids_dict is not None:
             self.finished_req_ids_dict[request.client_index].add(request_id)
@@ -983,7 +962,6 @@ class Scheduler(SchedulerInterface):
 
     def _free_blocks(self, request: Request):
         assert request.is_finished()
-        assert request.request_id not in self._cached_reqs_data
         self.kv_cache_manager.free(request)
         self.kv_cache_manager.free_block_hashes(request)
         del self.requests[request.request_id]
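A side note on the flag construction in _make_cached_request_data above: the in-line comment about "in-place appending" refers to the fact that += extends the existing list object, while a + b would allocate a new list. A tiny standalone illustration (not part of the commit):

running_reqs, resumed_reqs = 3, 0     # a typical step with no preemptions
flags = [False] * running_reqs
before = id(flags)
flags += [True] * resumed_reqs        # extends in place, even when resumed_reqs > 0
assert id(flags) == before            # same list object, no new allocation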
@@ -470,34 +470,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             req_ids_to_add.append(req_id)
 
         # Update the states of the running/resumed requests.
-        for req_data in scheduler_output.scheduled_cached_reqs:
-            req_id = req_data.req_id
+        req_data = scheduler_output.scheduled_cached_reqs
+        for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
+            num_computed_tokens = req_data.num_computed_tokens[i]
+            new_token_ids = req_data.new_token_ids[i]
+            new_block_ids = req_data.new_block_ids[i]
+            resumed_from_preemption = req_data.resumed_from_preemption[i]
 
             # Update the cached states.
-            num_computed_tokens = req_data.num_computed_tokens
             req_state.num_computed_tokens = num_computed_tokens
             # Add the sampled token(s) from the previous step (if any).
             # This doesn't include "unverified" tokens like spec decode tokens.
-            num_new_tokens = (num_computed_tokens +
-                              len(req_data.new_token_ids) -
+            num_new_tokens = (num_computed_tokens + len(new_token_ids) -
                               req_state.num_tokens)
             if num_new_tokens == 1:
                 # Avoid slicing list in most common case.
-                req_state.output_token_ids.append(req_data.new_token_ids[-1])
+                req_state.output_token_ids.append(new_token_ids[-1])
             elif num_new_tokens > 0:
                 req_state.output_token_ids.extend(
-                    req_data.new_token_ids[-num_new_tokens:])
+                    new_token_ids[-num_new_tokens:])
             # Update the block IDs.
-            if not req_data.resumed_from_preemption:
+            if not resumed_from_preemption:
                 # Append the new blocks to the existing block IDs.
-                for block_ids, new_block_ids in zip(req_state.block_ids,
-                                                    req_data.new_block_ids):
-                    block_ids.extend(new_block_ids)
+                for block_ids, new_ids in zip(req_state.block_ids,
+                                              new_block_ids):
+                    block_ids.extend(new_ids)
             else:
                 # The request is resumed from preemption.
                 # Replace the existing block IDs with the new ones.
-                req_state.block_ids = req_data.new_block_ids
+                req_state.block_ids = new_block_ids
 
             req_index = self.input_batch.req_id_to_index.get(req_id)
             if req_index is None:
@@ -510,14 +512,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
-            self.input_batch.block_table.append_row(req_data.new_block_ids,
-                                                    req_index)
+            self.input_batch.block_table.append_row(new_block_ids, req_index)
             # Add new_token_ids to token_ids_cpu.
             start_token_index = num_computed_tokens
-            end_token_index = num_computed_tokens + len(req_data.new_token_ids)
+            end_token_index = num_computed_tokens + len(new_token_ids)
             self.input_batch.token_ids_cpu[
-                req_index,
-                start_token_index:end_token_index] = req_data.new_token_ids
+                req_index, start_token_index:end_token_index] = new_token_ids
             self.input_batch.num_tokens_no_spec[req_index] = end_token_index
             # Add spec_token_ids to token_ids_cpu.
             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
@@ -418,21 +418,24 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             req_ids_to_add.append(req_id)
 
         # Update the states of the running/resumed requests.
-        for req_data in scheduler_output.scheduled_cached_reqs:
-            req_id = req_data.req_id
+        req_data = scheduler_output.scheduled_cached_reqs
+        for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
+            num_computed_tokens = req_data.num_computed_tokens[i]
+            new_block_ids = req_data.new_block_ids[i]
+            resumed_from_preemption = req_data.resumed_from_preemption[i]
 
             # Update the cached states.
-            req_state.num_computed_tokens = req_data.num_computed_tokens
-            if not req_data.resumed_from_preemption:
+            req_state.num_computed_tokens = num_computed_tokens
+            if not resumed_from_preemption:
                 # Append the new blocks to the existing block IDs.
-                for block_ids, new_block_ids in zip(req_state.block_ids,
-                                                    req_data.new_block_ids):
-                    block_ids.extend(new_block_ids)
+                for block_ids, new_ids in zip(req_state.block_ids,
+                                              new_block_ids):
+                    block_ids.extend(new_ids)
             else:
                 # The request is resumed from preemption.
                 # Replace the existing block IDs with the new ones.
-                req_state.block_ids = req_data.new_block_ids
+                req_state.block_ids = new_block_ids
 
             req_index = self.input_batch.req_id_to_index.get(req_id)
             if req_index is None:
@@ -444,9 +447,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
 
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
-                req_data.num_computed_tokens)
-            self.input_batch.block_table.append_row(req_data.new_block_ids,
-                                                    req_index)
+                num_computed_tokens)
+            self.input_batch.block_table.append_row(new_block_ids, req_index)
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.