[Optimization] Cache sampled token ids in model runner (#20291)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon, 2025-07-01 11:01:31 -07:00 (committed by GitHub)
Parent: 02cabff207
Commit: 7f280d69c9
5 changed files with 91 additions and 45 deletions


@@ -172,7 +172,7 @@ def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
         req_state.block_ids[0]).all()


-def test_update_states_new_request(model_runner):
+def test_update_states_new_request(model_runner, dist_init):
     req_id = "req_0"

     # new req
@@ -186,7 +186,7 @@ def test_update_states_new_request(model_runner):
     assert _is_req_state_block_table_match(model_runner, req_id)


-def test_update_states_request_finished(model_runner):
+def test_update_states_request_finished(model_runner, dist_init):
     req_id = "req_0"

     # new req
@@ -218,7 +218,7 @@ def test_update_states_request_finished(model_runner):
     assert not _is_req_scheduled(model_runner, req_id)


-def test_update_states_request_resumed(model_runner):
+def test_update_states_request_resumed(model_runner, dist_init):
     req_id = "req_0"

     # new req
@@ -278,7 +278,7 @@ def test_update_states_request_resumed(model_runner):
     assert _is_req_state_block_table_match(model_runner, req_id)


-def test_get_nans_in_logits(model_runner):
+def test_get_nans_in_logits(model_runner, dist_init):
     req_ids = ("req_0", "req_1")
     scheduler_output = _schedule_new_request(*req_ids)
@@ -326,7 +326,7 @@ def test_get_nans_in_logits(model_runner):
     assert result == {'req_0': 2, 'req_1': 0}


-def test_update_states_no_changes(model_runner):
+def test_update_states_no_changes(model_runner, dist_init):
     req_id = "req_0"

     # new req
@@ -359,7 +359,7 @@ def test_update_states_no_changes(model_runner):
     assert _is_req_state_block_table_match(model_runner, req_id)


-def test_update_states_request_unscheduled(model_runner):
+def test_update_states_request_unscheduled(model_runner, dist_init):
     req_ids = ("req_0", "req_1")

     # new reqs
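These tests now take a dist_init fixture because the GPUModelRunner path they exercise calls get_pp_group() after this change (see the model runner hunks further down), which requires an initialized distributed environment even in a single-process test. The fixture below is a hypothetical stand-in written for illustration, not the project's actual dist_init; it only sketches the kind of setup such a fixture has to perform.

# Hypothetical stand-in for a dist_init-style fixture: bring up a one-process
# gloo group so code that queries process groups (e.g. pipeline-parallel rank
# checks) does not fail, then tear it down after the test.
import pytest
import torch.distributed as dist


@pytest.fixture
def dist_init():
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo",
                                init_method="tcp://127.0.0.1:29500",
                                world_size=1,
                                rank=0)
    yield
    if dist.is_initialized():
        dist.destroy_process_group()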


@@ -307,7 +307,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
         cached_reqs = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(cached_reqs.req_ids):
             num_computed_tokens = cached_reqs.num_computed_tokens[i]
-            new_token_ids = cached_reqs.new_token_ids[i]
+            num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
             new_block_ids = cached_reqs.new_block_ids[i]
             resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
@@ -320,7 +320,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
                 # list of token ids (only new tokens). So we look it
                 # up in the actual request object.
                 request = self._requests_need_load[req_id]
-                total_tokens = (len(new_token_ids) + num_computed_tokens)
+                total_tokens = num_computed_tokens + num_new_tokens
                 token_ids = request.all_token_ids[:total_tokens]

                 # NOTE(rob): For resumed req, new_block_ids is all
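Since new_token_ids is now empty whenever pipeline parallelism is off (see the new NOTE on CachedRequestData below), the connector derives the request's current length from num_scheduled_tokens instead of from the token list it used to receive. A tiny self-contained sketch of the arithmetic, with made-up numbers:

# Made-up numbers: a request with 16 already-computed tokens gets 3 more
# tokens scheduled this step.
num_computed_tokens = 16
num_new_tokens = 3                 # scheduler_output.num_scheduled_tokens[req_id]
new_token_ids: list[int] = []      # empty when PP is not used

total_tokens = num_computed_tokens + num_new_tokens
assert total_tokens == 19

# The old formula silently undercounts once the token list stops being sent.
assert len(new_token_ids) + num_computed_tokens == 16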


@@ -88,6 +88,8 @@ class CachedRequestData:
     # the request's block IDs. If True, new_block_ids will be used as the
     # request's block IDs instead of appending to the existing block IDs.
     resumed_from_preemption: list[bool]
+    # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
+    # When PP is not used, new_token_ids will be empty.
     new_token_ids: list[list[int]]
     new_block_ids: list[tuple[list[int], ...]]
     num_computed_tokens: list[int]
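The new comment changes the contract of CachedRequestData: new_token_ids is populated only when pipeline parallelism is enabled, so consumers must not use it to infer how many tokens were scheduled. A minimal mock of the structure, assuming nothing beyond the fields shown above (this is not vLLM's class):

from dataclasses import dataclass


@dataclass
class MockCachedRequestData:
    req_ids: list[str]
    resumed_from_preemption: list[bool]
    # Populated only when pipeline parallelism is enabled; empty otherwise.
    new_token_ids: list[list[int]]
    new_block_ids: list[tuple[list[int], ...]]
    num_computed_tokens: list[int]


# Non-PP case: one cached request, and no token ids travel with it.
data = MockCachedRequestData(
    req_ids=["req_0"],
    resumed_from_preemption=[False],
    new_token_ids=[],
    new_block_ids=[([3, 4],)],
    num_computed_tokens=[16],
)
assert not data.new_token_ids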


@@ -55,6 +55,7 @@ class Scheduler(SchedulerInterface):
         self.lora_config = vllm_config.lora_config
         self.kv_cache_config = kv_cache_config
         self.kv_events_config = vllm_config.kv_events_config
+        self.parallel_config = vllm_config.parallel_config
         self.log_stats = log_stats
         self.structured_output_manager = structured_output_manager
@@ -87,7 +88,7 @@ class Scheduler(SchedulerInterface):
         self.kv_event_publisher = EventPublisherFactory.create(
             self.kv_events_config,
-            vllm_config.parallel_config.data_parallel_rank,
+            self.parallel_config.data_parallel_rank,
         )

         num_gpu_blocks = self.cache_config.num_gpu_blocks
@@ -159,6 +160,7 @@ class Scheduler(SchedulerInterface):
             log_stats=self.log_stats,
             enable_kv_cache_events=self.enable_kv_cache_events,
         )
+        self.use_pp = self.parallel_config.pipeline_parallel_size > 1

     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
@@ -214,7 +216,7 @@
                 # This is necessary when using spec decoding.
                 num_new_tokens = min(
                     num_new_tokens,
-                    self.max_model_len - request.num_computed_tokens)
+                    self.max_model_len - 1 - request.num_computed_tokens)

             # Schedule encoder inputs.
             encoder_inputs_to_schedule = None
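Subtracting one more token here appears to reserve a slot for the token that the last-rank worker samples and caches locally at the end of the step: the new code in GPUModelRunner asserts that the cached ids never run past max_model_len, so the scheduler has to stop one position short. This reading of the change is an inference, not stated in the commit; a worked example with small made-up numbers:

# Made-up sizes: an 8-token context, a request with 6 computed tokens, and a
# scheduler that would like to schedule 4 more.
max_model_len = 8
num_computed_tokens = 6
num_new_tokens = 4

# Old clamp: schedules 2 tokens, so the token sampled afterwards would land
# at index 8 and violate the end_idx <= max_model_len check added below.
old = min(num_new_tokens, max_model_len - num_computed_tokens)
assert num_computed_tokens + old + 1 > max_model_len

# New clamp: schedules 1 token, leaving slot 7 free for the cached sample.
new = min(num_new_tokens, max_model_len - 1 - num_computed_tokens)
assert num_computed_tokens + new + 1 <= max_model_len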
@@ -624,9 +626,15 @@
             req_ids.append(req_id)
             num_tokens = (num_scheduled_tokens[req_id] -
                           len(spec_decode_tokens.get(req_id, ())))
-            token_ids = req.all_token_ids[req.num_computed_tokens:req.
-                                          num_computed_tokens + num_tokens]
-            new_token_ids.append(token_ids)
+            if self.use_pp:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker. Otherwise, we don't
+                # need to send the sampled tokens back because the model runner
+                # will cache them.
+                token_ids = req.all_token_ids[req.num_computed_tokens:req.
+                                              num_computed_tokens + num_tokens]
+                new_token_ids.append(token_ids)
             new_block_ids.append(req_to_new_block_ids[req_id])
             num_computed_tokens.append(req.num_computed_tokens)
         # Because resumed_reqs is usually empty, it is more efficient to do
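Only when self.use_pp is set does the scheduler still slice the newly computed token ids out of the request and ship them to the workers; otherwise new_token_ids stays empty and the last-rank model runner is trusted to have cached them (see the GPUModelRunner changes below). A self-contained sketch of that slice, with invented values:

# Invented request state: 4 prompt tokens already computed, 2 more tokens
# scheduled this step, no spec-decode tokens.
all_token_ids = [101, 102, 103, 104, 7, 11]   # prompt + tokens sampled so far
num_computed_tokens = 4
num_tokens = 2
use_pp = True                                 # pipeline_parallel_size > 1

new_token_ids: list[list[int]] = []
if use_pp:
    token_ids = all_token_ids[num_computed_tokens:num_computed_tokens +
                              num_tokens]
    new_token_ids.append(token_ids)

assert new_token_ids == [[7, 11]]
# With use_pp = False nothing is appended, matching the empty-list contract
# documented on CachedRequestData.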


@@ -470,26 +470,33 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             req_ids_to_add.append(req_id)

         # Update the states of the running/resumed requests.
+        is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
-            new_token_ids = req_data.new_token_ids[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]

             # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
-            # Add the sampled token(s) from the previous step (if any).
-            # This doesn't include "unverified" tokens like spec decode tokens.
-            num_new_tokens = (num_computed_tokens + len(new_token_ids) -
-                              req_state.num_tokens)
-            if num_new_tokens == 1:
-                # Avoid slicing list in most common case.
-                req_state.output_token_ids.append(new_token_ids[-1])
-            elif num_new_tokens > 0:
-                req_state.output_token_ids.extend(
-                    new_token_ids[-num_new_tokens:])
+
+            if not is_last_rank:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker.
+                new_token_ids = req_data.new_token_ids[i]
+                # Add the sampled token(s) from the previous step (if any).
+                # This doesn't include "unverified" tokens like spec tokens.
+                num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                  req_state.num_tokens)
+                if num_new_tokens == 1:
+                    # Avoid slicing list in most common case.
+                    req_state.output_token_ids.append(new_token_ids[-1])
+                elif num_new_tokens > 0:
+                    req_state.output_token_ids.extend(
+                        new_token_ids[-num_new_tokens:])
+
             # Update the block IDs.
             if not resumed_from_preemption:
                 # Append the new blocks to the existing block IDs.
@@ -513,22 +520,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
             self.input_batch.block_table.append_row(new_block_ids, req_index)
-            # Add new_token_ids to token_ids_cpu.
-            start_token_index = num_computed_tokens
-            end_token_index = num_computed_tokens + len(new_token_ids)
-            self.input_batch.token_ids_cpu[
-                req_index, start_token_index:end_token_index] = new_token_ids
-            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
-            # Add spec_token_ids to token_ids_cpu.
-            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                req_id, ())
-            if spec_token_ids:
-                start_index = end_token_index
-                end_token_index += len(spec_token_ids)
-                self.input_batch.token_ids_cpu[
-                    req_index, start_index:end_token_index] = spec_token_ids
-            # NOTE(woosuk): `num_tokens` here may include spec decode tokens.
-            self.input_batch.num_tokens[req_index] = end_token_index
+
+            # For the last rank, we don't need to update the token_ids_cpu
+            # because the sampled tokens are already cached.
+            if not is_last_rank:
+                # Add new_token_ids to token_ids_cpu.
+                start_token_index = num_computed_tokens
+                end_token_index = num_computed_tokens + len(new_token_ids)
+                self.input_batch.token_ids_cpu[
+                    req_index,
+                    start_token_index:end_token_index] = new_token_ids
+                self.input_batch.num_tokens_no_spec[
+                    req_index] = end_token_index
+
+                # Add spec_token_ids to token_ids_cpu.
+                spec_token_ids = (
+                    scheduler_output.scheduled_spec_decode_tokens.get(
+                        req_id, ()))
+                if spec_token_ids:
+                    start_index = end_token_index
+                    end_token_index += len(spec_token_ids)
+                    self.input_batch.token_ids_cpu[
+                        req_index,
+                        start_index:end_token_index] = spec_token_ids
+
+                # NOTE(woosuk): `num_tokens` here may include spec tokens.
+                self.input_batch.num_tokens[req_index] = end_token_index

         # Check if the batch has changed. If not, we can skip copying the
         # sampling metadata from CPU to GPU.
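On the last pipeline rank, which is every rank when PP is disabled, the CPU token buffer was already filled in right after sampling, so _update_states can skip it; only non-last ranks still copy the scheduler-provided tokens (plus any spec-decode tokens) into place. A simplified, self-contained version of that copy using a NumPy buffer; the helper and variable names are local to this sketch, not vLLM's:

import numpy as np

# Toy stand-ins for input_batch.token_ids_cpu and its bookkeeping arrays.
token_ids_cpu = np.zeros((1, 16), dtype=np.int64)
num_tokens_no_spec = np.zeros(1, dtype=np.int64)
num_tokens = np.zeros(1, dtype=np.int64)


def update_row(req_index: int, num_computed: int,
               new_token_ids: list[int], spec_token_ids: list[int],
               is_last_rank: bool) -> None:
    # Last rank: the sampled ids were cached right after sampling, skip.
    if is_last_rank:
        return
    start = num_computed
    end = num_computed + len(new_token_ids)
    token_ids_cpu[req_index, start:end] = new_token_ids
    num_tokens_no_spec[req_index] = end
    if spec_token_ids:
        token_ids_cpu[req_index, end:end + len(spec_token_ids)] = spec_token_ids
        end += len(spec_token_ids)
    # May include spec tokens, mirroring the NOTE in the hunk above.
    num_tokens[req_index] = end


update_row(0, num_computed=4, new_token_ids=[7], spec_token_ids=[9, 9],
           is_last_rank=False)
assert list(token_ids_cpu[0, 4:7]) == [7, 9, 9]
assert num_tokens_no_spec[0] == 5 and num_tokens[0] == 7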
@@ -1509,6 +1524,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         for i in discard_sampled_tokens_req_indices:
             valid_sampled_token_ids[i].clear()

+        # Cache the sampled tokens in the model runner, so that the scheduler
+        # doesn't need to send them back.
+        # NOTE(woosuk): As an exception, when using PP, the scheduler sends
+        # the sampled tokens back, because there's no direct communication
+        # between the first-stage worker and the last-stage worker.
+        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
+            if not sampled_ids:
+                continue
+
+            start_idx = self.input_batch.num_tokens_no_spec[req_idx]
+            end_idx = start_idx + len(sampled_ids)
+            assert end_idx <= self.max_model_len, (
+                "Sampled token IDs exceed the max model length. "
+                f"Total number of tokens: {end_idx} > max_model_len: "
+                f"{self.max_model_len}")
+
+            self.input_batch.token_ids_cpu[req_idx,
+                                           start_idx:end_idx] = sampled_ids
+            self.input_batch.num_tokens_no_spec[req_idx] = end_idx
+            self.input_batch.num_tokens[req_idx] = end_idx
+            req_id = self.input_batch.req_ids[req_idx]
+            req_state = self.requests[req_id]
+            req_state.output_token_ids.extend(sampled_ids)
+
         if not self.speculative_config:
             # Speculative decoding is not enabled.
             spec_token_ids = None
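This block is the core of the optimization: right after sampling, the last-rank worker writes the valid sampled ids into its own CPU token buffer and into the per-request output history, which is why the scheduler no longer needs to send them back on the next step. A condensed, runnable sketch of the same bookkeeping, with a NumPy array and plain dicts standing in for the input batch and request state:

import numpy as np

max_model_len = 8
token_ids_cpu = np.zeros((2, max_model_len), dtype=np.int64)
num_tokens_no_spec = np.array([3, 5], dtype=np.int64)   # per-request lengths
num_tokens = num_tokens_no_spec.copy()
output_token_ids: dict[str, list[int]] = {"req_0": [], "req_1": []}
req_ids = ["req_0", "req_1"]

valid_sampled_token_ids = [[42], []]   # req_1 sampled nothing this step

for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
    if not sampled_ids:
        continue
    start_idx = int(num_tokens_no_spec[req_idx])
    end_idx = start_idx + len(sampled_ids)
    assert end_idx <= max_model_len, "sampled ids would exceed max_model_len"
    token_ids_cpu[req_idx, start_idx:end_idx] = sampled_ids
    num_tokens_no_spec[req_idx] = end_idx
    num_tokens[req_idx] = end_idx
    output_token_ids[req_ids[req_idx]].extend(sampled_ids)

assert token_ids_cpu[0, 3] == 42 and num_tokens_no_spec[0] == 4
assert output_token_ids == {"req_0": [42], "req_1": []}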
@@ -1730,17 +1769,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 draft_token_ids.append([])
                 continue

-            # Add sampled_token_ids to token_ids_cpu.
-            start_idx = self.input_batch.num_tokens_no_spec[i]
-            end_idx = start_idx + num_sampled_ids
-            if end_idx >= self.max_model_len:
+            num_tokens = self.input_batch.num_tokens_no_spec[i]
+            if num_tokens >= self.max_model_len:
                 # Skip requests that have already reached the max model length.
                 draft_token_ids.append([])
                 continue

-            self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
             drafter_output = self.drafter.propose(
-                self.input_batch.token_ids_cpu[i, :end_idx])
+                self.input_batch.token_ids_cpu[i, :num_tokens])
             if drafter_output is None or len(drafter_output) == 0:
                 draft_token_ids.append([])
             else:
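Because the sampled ids are already sitting in token_ids_cpu by the time drafting runs, the drafter no longer re-writes them; it simply reads the first num_tokens_no_spec[i] entries of the cached row, and the max-length guard compares that cached length directly instead of a freshly computed end index. A minimal check on a toy buffer (names invented for the sketch):

import numpy as np

max_model_len = 8
token_ids_cpu = np.zeros((1, max_model_len), dtype=np.int64)
token_ids_cpu[0, :4] = [101, 102, 103, 42]    # 42 was cached right after sampling
num_tokens_no_spec = np.array([4], dtype=np.int64)

num_tokens = int(num_tokens_no_spec[0])
assert num_tokens < max_model_len              # otherwise drafting is skipped
drafter_input = token_ids_cpu[0, :num_tokens]  # no extra write step needed
assert list(drafter_input) == [101, 102, 103, 42]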