[Optimization] Use Shared CachedRequestData Instance Across All Requests (#20232)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon, 2025-06-30 09:07:50 -07:00 (committed by GitHub)
Parent: 2965c99c86
Commit: 2863befce3
12 changed files with 220 additions and 231 deletions

View File

@ -10,7 +10,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
@ -198,7 +198,7 @@ def test_schedule(enable_prefix_caching: Optional[bool],
# Test initial scheduling
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert len(output.scheduled_cached_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
@ -225,7 +225,7 @@ def test_schedule_multimodal_requests():
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert len(output.scheduled_cached_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
@ -259,7 +259,7 @@ def test_schedule_partial_requests():
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3
assert len(output.scheduled_cached_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
assert scheduler.max_num_encoder_input_tokens == 1024
@ -295,7 +295,7 @@ def test_schedule_partial_requests():
output = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output.scheduled_new_reqs) == 0
assert len(output.scheduled_cached_reqs) == 2
assert output.scheduled_cached_reqs.num_reqs == 2
assert len(output.finished_req_ids) == 0
assert output.num_scheduled_tokens[requests[0].request_id] == 1
assert output.num_scheduled_tokens[requests[1].request_id] == 700
@ -319,7 +319,7 @@ def test_no_mm_input_chunking():
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 1
assert len(output.scheduled_cached_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# We want to only see the 400 text tokens at the start scheduled
assert output.num_scheduled_tokens[requests[0].request_id] == 400
@ -342,7 +342,7 @@ def test_no_mm_input_chunking():
output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(output.scheduled_new_reqs) == 0
assert len(output.scheduled_cached_reqs) == 1
assert output.scheduled_cached_reqs.num_reqs == 1
assert len(output.finished_req_ids) == 0
assert output.num_scheduled_tokens[requests[0].request_id] == 800
@ -379,7 +379,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3
assert len(output.scheduled_cached_reqs) == 0
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# The first request is scheduled partially - 400.
@ -408,7 +408,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output1 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output1.scheduled_new_reqs) == 0
assert len(output1.scheduled_cached_reqs) == 3
assert output1.scheduled_cached_reqs.num_reqs == 3
assert len(output1.finished_req_ids) == 0
assert output1.num_scheduled_tokens[requests[0].request_id] == 400
assert output1.num_scheduled_tokens[requests[1].request_id] == 400
@ -430,7 +430,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output2 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output2.scheduled_new_reqs) == 0
assert len(output2.scheduled_cached_reqs) == 3
assert output2.scheduled_cached_reqs.num_reqs == 3
assert len(output2.finished_req_ids) == 0
assert output2.num_scheduled_tokens[requests[0].request_id] == 1
assert output2.num_scheduled_tokens[requests[1].request_id] == 1
@ -449,23 +449,24 @@ def test_stop_via_update_from_output():
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
@ -501,23 +502,25 @@ def test_stop_via_update_from_output():
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
@ -551,23 +554,25 @@ def test_stop_via_update_from_output():
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
@ -603,7 +608,7 @@ def test_stop_via_update_from_output():
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
@ -1208,7 +1213,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
assert len(scheduler._cached_reqs_data) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0
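Note on the test changes above: `scheduled_cached_reqs` is no longer a list, so emptiness is checked via the new `num_reqs` property, and hand-built `SchedulerOutput`s pass `CachedRequestData.make_empty()` instead of `[]`. A minimal sketch of the updated test-facing API, assuming vLLM at this commit is importable:

from vllm.v1.core.sched.output import CachedRequestData

# A scheduling step with no previously scheduled requests now carries one
# shared, empty CachedRequestData rather than an empty list.
empty = CachedRequestData.make_empty()
assert empty.num_reqs == 0  # replaces the old `len(scheduled_cached_reqs) == 0`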

View File

@ -66,7 +66,7 @@ def test_basic_lifecycle():
assert len(scheduler_output.finished_req_ids) == 1
assert request_id in scheduler_output.finished_req_ids
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0
# (2b): execute_model()
@ -81,7 +81,7 @@ def test_basic_lifecycle():
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0
# (3b): execute_model()

View File

@ -36,7 +36,7 @@ def test_basic_lifecycle():
# Nothing running and empty scheduler output.
assert len(scheduler.running) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler_output.num_scheduled_tokens) == 0
assert scheduler_output.total_num_scheduled_tokens == 0
@ -158,7 +158,7 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
assert len(scheduler_output.scheduled_cached_reqs) == 1
assert scheduler_output.scheduled_cached_reqs.num_reqs == 1
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b])
@ -169,7 +169,7 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 2
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output(
reqs=[request_local_a, request_local_b])
@ -177,14 +177,14 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 2
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
# STEP 4: KVs arrive.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
assert len(scheduler_output.scheduled_cached_reqs) == 2
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b],
@ -196,7 +196,7 @@ def test_interleaved_lifecycle():
assert len(scheduler.running) == 3
assert len(scheduler.waiting) == 0
assert len(scheduler_output.scheduled_new_reqs) == 1
assert len(scheduler_output.scheduled_cached_reqs) == 2
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
model_runner_output = create_model_runner_output(
[request_local_a, request_local_b, request_remote])

View File

@ -25,7 +25,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
assert len(scheduler.finished_recving_kv_req_ids) == 0
assert len(scheduler._cached_reqs_data) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0

View File

@ -82,7 +82,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
return SchedulerOutput(
scheduled_new_reqs=new_reqs,
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
@ -161,7 +161,7 @@ def test_update_states_request_finished(model_runner):
# finish req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
@ -191,7 +191,7 @@ def test_update_states_request_resumed(model_runner):
# unschedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
@ -209,16 +209,16 @@ def test_update_states_request_resumed(model_runner):
# resume req
cached_req_data = CachedRequestData(
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=([], ),
num_computed_tokens=0,
req_ids=[req_id],
resumed_from_preemption=[False],
new_token_ids=[[]],
new_block_ids=[([], )],
num_computed_tokens=[0],
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[cached_req_data],
scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
@ -249,7 +249,7 @@ def test_update_states_no_changes(model_runner):
# schedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
@ -284,7 +284,7 @@ def test_update_states_request_unscheduled(model_runner):
# unschedule req_1
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_ids[0]: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},

View File

@ -133,7 +133,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
return SchedulerOutput(
scheduled_new_reqs=new_reqs,
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
@ -199,7 +199,7 @@ def test_update_states_request_finished(model_runner):
# finish req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
@ -231,7 +231,7 @@ def test_update_states_request_resumed(model_runner):
# unschedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
@ -249,16 +249,16 @@ def test_update_states_request_resumed(model_runner):
# resume req
cached_req_data = CachedRequestData(
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=([], ),
num_computed_tokens=0,
req_ids=[req_id],
resumed_from_preemption=[False],
new_token_ids=[[]],
new_block_ids=([[0]], ),
num_computed_tokens=[0],
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[cached_req_data],
scheduled_cached_reqs=cached_req_data,
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
@ -339,7 +339,7 @@ def test_update_states_no_changes(model_runner):
# schedule req
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
@ -376,7 +376,7 @@ def test_update_states_request_unscheduled(model_runner):
# unschedule req_1
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={req_ids[0]: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
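The constructor change above is representative: every per-request scalar field became a parallel list with one entry per request. An illustrative sketch of a two-request batch (the concrete values here are invented):

from vllm.v1.core.sched.output import CachedRequestData

batch = CachedRequestData(
    req_ids=["req-0", "req-1"],
    resumed_from_preemption=[False, True],   # req-1 was preempted and resumed
    new_token_ids=[[101], []],
    new_block_ids=[([7],), ([0, 1, 2],)],    # one tuple entry per KV-cache group
    num_computed_tokens=[42, 0],
)
assert batch.num_reqs == 2
# Per-request values are recovered positionally:
for i, req_id in enumerate(batch.req_ids):
    print(req_id, batch.num_computed_tokens[i], batch.resumed_from_preemption[i])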

View File

@ -371,45 +371,48 @@ class P2pNcclConnector(KVConnectorBase_V1):
block_size=self._block_size)
self._requests_need_load.pop(new_req.req_id)
for cached_req in scheduler_output.scheduled_cached_reqs:
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
num_computed_tokens = cached_reqs.num_computed_tokens[i]
new_block_ids = cached_reqs.new_block_ids[i]
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
if self.is_producer:
num_scheduled_tokens = (
scheduler_output.num_scheduled_tokens)[cached_req.req_id]
num_tokens = (num_scheduled_tokens +
cached_req.num_computed_tokens)
assert cached_req.req_id in self.chunked_prefill
block_ids = cached_req.new_block_ids[0]
if not cached_req.resumed_from_preemption:
block_ids = (self.chunked_prefill[cached_req.req_id][0] +
block_ids)
prompt_token_ids = self.chunked_prefill[cached_req.req_id][1]
scheduler_output.num_scheduled_tokens)[req_id]
num_tokens = (num_scheduled_tokens + num_computed_tokens)
assert req_id in self.chunked_prefill
block_ids = new_block_ids[0]
if not resumed_from_preemption:
block_ids = (self.chunked_prefill[req_id][0] + block_ids)
prompt_token_ids = self.chunked_prefill[req_id][1]
# the request's prompt is chunked prefill again
if num_tokens < len(prompt_token_ids):
self.chunked_prefill[cached_req.req_id] = (
block_ids, prompt_token_ids)
self.chunked_prefill[req_id] = (block_ids,
prompt_token_ids)
continue
# the request's prompt is all prefilled finally
meta.add_request(request_id=cached_req.req_id,
meta.add_request(request_id=req_id,
token_ids=prompt_token_ids,
block_ids=block_ids,
block_size=self._block_size)
self.chunked_prefill.pop(cached_req.req_id, None)
self.chunked_prefill.pop(req_id, None)
continue
# NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs.
if not cached_req.resumed_from_preemption:
if not resumed_from_preemption:
break
if cached_req.req_id in self._requests_need_load:
request, _ = self._requests_need_load.pop(cached_req.req_id)
total_tokens = cached_req.num_computed_tokens + 1
if req_id in self._requests_need_load:
request, _ = self._requests_need_load.pop(req_id)
total_tokens = num_computed_tokens + 1
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
block_ids = cached_req.new_block_ids[0]
block_ids = new_block_ids[0]
meta.add_request(request_id=cached_req.req_id,
meta.add_request(request_id=req_id,
token_ids=token_ids,
block_ids=block_ids,
block_size=self._block_size)

View File

@ -304,23 +304,28 @@ class SharedStorageConnector(KVConnectorBase_V1):
block_size=self._block_size,
is_store=True)
for cached_req in scheduler_output.scheduled_cached_reqs:
cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids):
num_computed_tokens = cached_reqs.num_computed_tokens[i]
new_token_ids = cached_reqs.new_token_ids[i]
new_block_ids = cached_reqs.new_block_ids[i]
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
# NOTE(rob): here we rely on the resumed requests being
# the first N requests in the list scheduled_cache_reqs.
if not cached_req.resumed_from_preemption:
if not resumed_from_preemption:
break
if cached_req.req_id in self._requests_need_load:
if req_id in self._requests_need_load:
# NOTE(rob): cached_req_data does not have the full
# list of token ids (only new tokens). So we look it
# up in the actual request object.
request = self._requests_need_load[cached_req.req_id]
total_tokens = (len(cached_req.new_token_ids) +
cached_req.num_computed_tokens)
request = self._requests_need_load[req_id]
total_tokens = (len(new_token_ids) + num_computed_tokens)
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request.
block_ids = cached_req.new_block_ids[0]
block_ids = new_block_ids[0]
meta.add_request(token_ids=token_ids,
block_ids=block_ids,
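Both connectors now consume the shared batch the same way: index into the parallel lists rather than iterate over per-request objects. A simplified sketch of that access pattern; the `handle` callback is a hypothetical placeholder for the connector-specific logic above:

def process_cached_reqs(scheduler_output, handle):
    # Walk the single shared CachedRequestData and hand per-request views
    # to `handle`, mirroring the loops in the two connectors above.
    cached_reqs = scheduler_output.scheduled_cached_reqs
    for i, req_id in enumerate(cached_reqs.req_ids):
        handle(
            req_id=req_id,
            num_computed_tokens=cached_reqs.num_computed_tokens[i],
            new_token_ids=cached_reqs.new_token_ids[i],
            new_block_ids=cached_reqs.new_block_ids[i],
            resumed_from_preemption=cached_reqs.resumed_from_preemption[i],
        )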

View File

@ -83,29 +83,27 @@ class NewRequestData:
@dataclass
class CachedRequestData:
req_id: str
req_ids: list[str]
# If resumed_from_preemption is False, new_block_ids will be appended to
# the request's block IDs. If True, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption: bool
new_token_ids: list[int]
new_block_ids: tuple[list[int], ...]
num_computed_tokens: int
resumed_from_preemption: list[bool]
new_token_ids: list[list[int]]
new_block_ids: list[tuple[list[int], ...]]
num_computed_tokens: list[int]
@property
def num_reqs(self) -> int:
return len(self.req_ids)
@classmethod
def from_request(
cls,
request: Request,
resumed_from_preemption: bool,
new_token_ids: list[int],
new_block_ids: tuple[list[int], ...],
) -> CachedRequestData:
def make_empty(cls) -> CachedRequestData:
return cls(
req_id=request.request_id,
resumed_from_preemption=resumed_from_preemption,
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=request.num_computed_tokens,
req_ids=[],
resumed_from_preemption=[],
new_token_ids=[],
new_block_ids=[],
num_computed_tokens=[],
)
@ -119,7 +117,7 @@ class SchedulerOutput:
# list of the requests that have been scheduled before.
# Since the request's data is already cached in the worker processes,
# we only send the diff to minimize the communication cost.
scheduled_cached_reqs: list[CachedRequestData]
scheduled_cached_reqs: CachedRequestData
# req_id -> num_scheduled_tokens
# Number of tokens scheduled for each request.
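This dataclass is the heart of the change: instead of one `CachedRequestData` object per request (plus a per-request reuse cache in the scheduler), each scheduling step now ships a single instance whose fields are parallel lists, i.e. a structure of arrays rather than a list of structures. Where a per-request view is still convenient, it can be rebuilt cheaply; a hypothetical helper, not part of this commit:

from vllm.v1.core.sched.output import CachedRequestData

def iter_cached_reqs(data: CachedRequestData):
    # Yield (req_id, resumed, new_token_ids, new_block_ids, num_computed_tokens)
    # tuples, one per cached request.
    yield from zip(
        data.req_ids,
        data.resumed_from_preemption,
        data.new_token_ids,
        data.new_block_ids,
        data.num_computed_tokens,
    )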

View File

@ -3,8 +3,9 @@
from __future__ import annotations
import itertools
import time
from collections import defaultdict, deque
from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Optional, Union
@ -117,12 +118,6 @@ class Scheduler(SchedulerInterface):
# KV Connector: requests in process of async KV loading or recving
self.finished_recving_kv_req_ids: set[str] = set()
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
# Request id -> deque of CachedRequestData
self._cached_reqs_data: dict[
str, deque[CachedRequestData]] = defaultdict(deque)
# Encoder-related.
# Calculate encoder cache size if applicable
# NOTE: For now we use the same budget for both compute and space.
@ -547,27 +542,16 @@ class Scheduler(SchedulerInterface):
req_to_new_block_ids[req.request_id])
for req in scheduled_new_reqs
]
resumed_reqs_data = [
self._make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
req_to_new_block_ids[req.request_id],
resumed_from_preemption=True,
) for req in scheduled_resumed_reqs
]
running_reqs_data = [
self._make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
req_to_new_block_ids[req.request_id],
resumed_from_preemption=False,
) for req in scheduled_running_reqs
]
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_block_ids,
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
@ -613,34 +597,39 @@ class Scheduler(SchedulerInterface):
def _make_cached_request_data(
self,
request: Request,
num_scheduled_tokens: int,
num_scheduled_spec_tokens: int,
new_block_ids: tuple[list[int], ...],
resumed_from_preemption: bool,
running_reqs: list[Request],
resumed_reqs: list[Request],
num_scheduled_tokens: dict[str, int],
spec_decode_tokens: dict[str, list[int]],
req_to_new_block_ids: dict[str, tuple[list[int], ...]],
) -> CachedRequestData:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
num_computed_tokens = request.num_computed_tokens
num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
new_token_ids = request.all_token_ids[
num_computed_tokens:num_computed_tokens + num_regular_tokens]
req_ids: list[str] = []
new_token_ids: list[list[int]] = []
new_block_ids: list[tuple[list[int], ...]] = []
num_computed_tokens: list[int] = []
req_data_queue = self._cached_reqs_data.get(request.request_id)
if req_data_queue:
req_data = req_data_queue.popleft()
req_data.resumed_from_preemption = resumed_from_preemption
req_data.new_token_ids = new_token_ids
req_data.new_block_ids = new_block_ids
req_data.num_computed_tokens = num_computed_tokens
else:
# No cached request data, or all cached request data has been
# used by the scheduled requests.
req_data = CachedRequestData.from_request(request,
resumed_from_preemption,
new_token_ids,
new_block_ids)
return req_data
for req in itertools.chain(running_reqs, resumed_reqs):
req_id = req.request_id
req_ids.append(req_id)
num_tokens = (num_scheduled_tokens[req_id] -
len(spec_decode_tokens.get(req_id, ())))
token_ids = req.all_token_ids[req.num_computed_tokens:req.
num_computed_tokens + num_tokens]
new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens)
# Because resumed_reqs is usually empty, it is more efficient to do
# in-place appending so that we don't need to allocate a new list.
resumed_from_preemption = [False] * len(running_reqs)
resumed_from_preemption += [True] * len(resumed_reqs)
return CachedRequestData(
req_ids=req_ids,
resumed_from_preemption=resumed_from_preemption,
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
)
def _try_schedule_encoder_inputs(
self,
@ -870,19 +859,11 @@ class Scheduler(SchedulerInterface):
if not stopped:
new_running.append(request)
self.running = new_running
# KV Connector: update state for finished KV Transfers.
self._update_from_kv_xfer_finished(model_runner_output)
# Return the cached request data to the queue so they can be reused.
for req_data in scheduler_output.scheduled_cached_reqs:
# NOTE(rob): since we free stopped reqs above, adding stopped reqs
# to _cached_reqs_data will cause a memory leak.
if req_data.req_id not in self.finished_req_ids:
self._cached_reqs_data[req_data.req_id].append(req_data)
self.running = new_running
# Create EngineCoreOutputs for all clients that have requests with
# outputs in this step.
engine_core_outputs = {
@ -965,13 +946,11 @@ class Scheduler(SchedulerInterface):
self._free_request(request)
def _free_request(self, request: Request) -> Optional[dict[str, Any]]:
assert request.is_finished()
delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request)
request_id = request.request_id
self._cached_reqs_data.pop(request_id, None)
self.finished_req_ids.add(request_id)
if self.finished_req_ids_dict is not None:
self.finished_req_ids_dict[request.client_index].add(request_id)
@ -983,7 +962,6 @@ class Scheduler(SchedulerInterface):
def _free_blocks(self, request: Request):
assert request.is_finished()
assert request.request_id not in self._cached_reqs_data
self.kv_cache_manager.free(request)
self.kv_cache_manager.free_block_hashes(request)
del self.requests[request.request_id]
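Net effect in the scheduler: the per-request `_cached_reqs_data` deque cache and its bookkeeping in `update_from_output`, `_free_request`, and `_free_blocks` are gone, and `_make_cached_request_data` builds one batched object per step from the running and resumed requests. A standalone sketch of that construction, with a hypothetical `SimpleReq` standing in for `vllm.v1.request.Request`:

import itertools
from dataclasses import dataclass

from vllm.v1.core.sched.output import CachedRequestData

@dataclass
class SimpleReq:  # minimal stand-in for vllm.v1.request.Request
    request_id: str
    all_token_ids: list[int]
    num_computed_tokens: int

def make_cached_request_data(running, resumed, num_scheduled_tokens,
                             spec_decode_tokens, req_to_new_block_ids):
    req_ids, new_token_ids, new_block_ids, num_computed = [], [], [], []
    for req in itertools.chain(running, resumed):
        req_id = req.request_id
        req_ids.append(req_id)
        # Speculative tokens are unverified and excluded from the new tokens.
        num_tokens = (num_scheduled_tokens[req_id] -
                      len(spec_decode_tokens.get(req_id, ())))
        new_token_ids.append(req.all_token_ids[
            req.num_computed_tokens:req.num_computed_tokens + num_tokens])
        new_block_ids.append(req_to_new_block_ids[req_id])
        num_computed.append(req.num_computed_tokens)
    # Running requests come first, resumed requests follow.
    resumed_from_preemption = [False] * len(running) + [True] * len(resumed)
    return CachedRequestData(
        req_ids=req_ids,
        resumed_from_preemption=resumed_from_preemption,
        new_token_ids=new_token_ids,
        new_block_ids=new_block_ids,
        num_computed_tokens=num_computed,
    )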

View File

@ -470,34 +470,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
req_ids_to_add.append(req_id)
# Update the states of the running/resumed requests.
for req_data in scheduler_output.scheduled_cached_reqs:
req_id = req_data.req_id
req_data = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_token_ids = req_data.new_token_ids[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states.
num_computed_tokens = req_data.num_computed_tokens
req_state.num_computed_tokens = num_computed_tokens
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec decode tokens.
num_new_tokens = (num_computed_tokens +
len(req_data.new_token_ids) -
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
req_state.num_tokens)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(req_data.new_token_ids[-1])
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
req_data.new_token_ids[-num_new_tokens:])
new_token_ids[-num_new_tokens:])
# Update the block IDs.
if not req_data.resumed_from_preemption:
if not resumed_from_preemption:
# Append the new blocks to the existing block IDs.
for block_ids, new_block_ids in zip(req_state.block_ids,
req_data.new_block_ids):
block_ids.extend(new_block_ids)
for block_ids, new_ids in zip(req_state.block_ids,
new_block_ids):
block_ids.extend(new_ids)
else:
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
req_state.block_ids = req_data.new_block_ids
req_state.block_ids = new_block_ids
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None:
@ -510,14 +512,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens)
self.input_batch.block_table.append_row(req_data.new_block_ids,
req_index)
self.input_batch.block_table.append_row(new_block_ids, req_index)
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(req_data.new_token_ids)
end_token_index = num_computed_tokens + len(new_token_ids)
self.input_batch.token_ids_cpu[
req_index,
start_token_index:end_token_index] = req_data.new_token_ids
req_index, start_token_index:end_token_index] = new_token_ids
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
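On the worker side, the consumption loop is the mirror image: enumerate the shared batch, pull each request's fields by index, and fold them into the cached per-request state. A simplified sketch of the per-request token and block bookkeeping; `ReqState` is a hypothetical, minimal stand-in for the runner's `CachedRequestState`, not the real class:

from dataclasses import dataclass

@dataclass
class ReqState:  # hypothetical stand-in for CachedRequestState
    num_tokens: int                    # tokens known to the runner before this update
    output_token_ids: list[int]
    block_ids: tuple[list[int], ...]   # one list per KV-cache group

def apply_cached_update(req_state: ReqState,
                        new_token_ids: list[int],
                        new_block_ids: tuple[list[int], ...],
                        num_computed_tokens: int,
                        resumed_from_preemption: bool) -> None:
    # Sampled tokens from the previous step (spec-decode tokens excluded).
    num_new = num_computed_tokens + len(new_token_ids) - req_state.num_tokens
    if num_new == 1:
        req_state.output_token_ids.append(new_token_ids[-1])
    elif num_new > 0:
        req_state.output_token_ids.extend(new_token_ids[-num_new:])
    if not resumed_from_preemption:
        # Append the new blocks to the existing per-group block IDs.
        for block_ids, new_ids in zip(req_state.block_ids, new_block_ids):
            block_ids.extend(new_ids)
    else:
        # Resumed from preemption: new_block_ids is the full block list.
        req_state.block_ids = new_block_ids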

View File

@ -418,21 +418,24 @@ class TPUModelRunner(LoRAModelRunnerMixin):
req_ids_to_add.append(req_id)
# Update the states of the running/resumed requests.
for req_data in scheduler_output.scheduled_cached_reqs:
req_id = req_data.req_id
req_data = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states.
req_state.num_computed_tokens = req_data.num_computed_tokens
if not req_data.resumed_from_preemption:
req_state.num_computed_tokens = num_computed_tokens
if not resumed_from_preemption:
# Append the new blocks to the existing block IDs.
for block_ids, new_block_ids in zip(req_state.block_ids,
req_data.new_block_ids):
block_ids.extend(new_block_ids)
for block_ids, new_ids in zip(req_state.block_ids,
new_block_ids):
block_ids.extend(new_ids)
else:
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
req_state.block_ids = req_data.new_block_ids
req_state.block_ids = new_block_ids
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None:
@ -444,9 +447,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
req_data.num_computed_tokens)
self.input_batch.block_table.append_row(req_data.new_block_ids,
req_index)
num_computed_tokens)
self.input_batch.block_table.append_row(new_block_ids, req_index)
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.