[V1] Get input tokens from scheduler (#13339)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Committed: 2025-02-17 11:01:07 -08:00 (via GitHub)
Parent: ce77eb9410
Commit: 4c21ce9eba
4 changed files with 139 additions and 139 deletions


@ -154,6 +154,7 @@ def test_update_states_request_resumed(model_runner):
cached_req_data = CachedRequestData(
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=[],
num_computed_tokens=0,
)
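For illustration only: a resumed request that had already produced sampled tokens would carry them in the new field. The values below are made up, not taken from the test.

# Hypothetical values, shown only to illustrate the new new_token_ids field.
resumed_req_data = CachedRequestData(
    req_id="req-1",                 # hypothetical request ID
    resumed_from_preemption=True,
    new_token_ids=[101, 102],       # tokens sampled since the last step
    new_block_ids=[7],              # block newly allocated on resumption
    num_computed_tokens=16,
)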


@ -121,6 +121,8 @@ class Scheduler:
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: Dict[str, List[int]] = {}
# For logging.
scheduled_timestamp = time.monotonic()
# First, schedule the RUNNING requests.
@ -187,6 +189,15 @@ class Scheduler:
token_budget -= num_new_tokens
req_index += 1
# Speculative decode related.
if request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens +
request.num_computed_tokens -
request.num_tokens)
if num_scheduled_spec_tokens > 0:
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids[:num_scheduled_spec_tokens])
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
@ -196,11 +207,6 @@ class Scheduler:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Speculative decode related.
if request.spec_token_ids:
scheduled_spec_decode_tokens[
request.request_id] = request.spec_token_ids
# Record the LoRAs in scheduled_running_reqs
requested_loras: Set[int] = set()
if self.lora_config:
@ -324,23 +330,24 @@ class Scheduler:
# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(req,
req_to_new_block_ids[req.request_id],
req.num_computed_tokens)
req_to_new_block_ids[req.request_id])
for req in scheduled_new_reqs
]
resumed_reqs_data = [
self._make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
req_to_new_block_ids[req.request_id],
req.num_computed_tokens,
resumed_from_preemption=True,
) for req in scheduled_resumed_reqs
]
running_reqs_data = [
self._make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
req_to_new_block_ids[req.request_id],
req.num_computed_tokens,
resumed_from_preemption=False,
) for req in scheduled_running_reqs
]
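As a worked example of the spec-token accounting in the scheduling loop above (all numbers are illustrative, not taken from the diff):

# A request with 10 verified tokens (prompt + sampled), 8 of them already
# computed, and 3 draft tokens proposed by the speculator.
num_tokens = 10                   # request.num_tokens
num_computed_tokens = 8           # request.num_computed_tokens
spec_token_ids = [71, 72, 73]     # request.spec_token_ids

num_new_tokens = 4                # tokens the scheduler decided to run now

# Any scheduled tokens past the verified ones are speculative.
num_scheduled_spec_tokens = num_new_tokens + num_computed_tokens - num_tokens
assert num_scheduled_spec_tokens == 2
scheduled_spec = spec_token_ids[:num_scheduled_spec_tokens]
assert scheduled_spec == [71, 72]  # these go into scheduled_spec_decode_tokens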
@ -349,8 +356,8 @@ class Scheduler:
scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
@ -366,22 +373,28 @@ class Scheduler:
def _make_cached_request_data(
self,
request: Request,
num_scheduled_tokens: int,
num_scheduled_spec_tokens: int,
new_block_ids: List[int],
num_computed_tokens: int,
resumed_from_preemption: bool,
) -> "CachedRequestData":
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
if request.request_id in self._cached_reqs_data:
req_data = self._cached_reqs_data[request.request_id]
num_computed_tokens = request.num_computed_tokens
num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
new_token_ids = request.all_token_ids[
num_computed_tokens:num_computed_tokens + num_regular_tokens]
req_data = self._cached_reqs_data.get(request.request_id)
if req_data is not None:
req_data.resumed_from_preemption = resumed_from_preemption
req_data.new_token_ids = new_token_ids
req_data.new_block_ids = new_block_ids
req_data.num_computed_tokens = num_computed_tokens
else:
req_data = CachedRequestData.from_request(request,
resumed_from_preemption,
new_block_ids,
num_computed_tokens)
new_token_ids,
new_block_ids)
self._cached_reqs_data[request.request_id] = req_data
return req_data
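A quick sketch of the slice computed above, with made-up numbers:

# Made-up numbers: 6 tokens scheduled this step, 2 of them speculative.
all_token_ids = list(range(100))          # stand-in for request.all_token_ids
num_computed_tokens = 90
num_scheduled_tokens = 6
num_scheduled_spec_tokens = 2

# Only verified ("regular") tokens are shipped as new_token_ids; the
# speculative ones travel separately via scheduled_spec_decode_tokens.
num_regular_tokens = num_scheduled_tokens - num_scheduled_spec_tokens
new_token_ids = all_token_ids[num_computed_tokens:
                              num_computed_tokens + num_regular_tokens]
assert new_token_ids == [90, 91, 92, 93]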


@ -30,7 +30,6 @@ class NewRequestData:
cls,
request: "Request",
block_ids: List[int],
num_computed_tokens: int,
) -> "NewRequestData":
return cls(
req_id=request.request_id,
@ -41,7 +40,7 @@ class NewRequestData:
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
block_ids=block_ids,
num_computed_tokens=num_computed_tokens,
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,
)
@ -54,6 +53,7 @@ class CachedRequestData:
# the request's block IDs. If True, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption: bool
new_token_ids: List[int]
new_block_ids: List[int]
num_computed_tokens: int
@ -62,14 +62,15 @@ class CachedRequestData:
cls,
request: "Request",
resumed_from_preemption: bool,
new_token_ids: List[int],
new_block_ids: List[int],
num_computed_tokens: int,
) -> "CachedRequestData":
return cls(
req_id=request.request_id,
resumed_from_preemption=resumed_from_preemption,
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
num_computed_tokens=request.num_computed_tokens,
)
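Putting the hunks above together, the updated dataclass now looks roughly like this (a sketch reconstructed from the diff, not the verbatim source):

from dataclasses import dataclass
from typing import List

@dataclass
class CachedRequestData:
    req_id: str
    # If resumed_from_preemption is False, new_block_ids are appended to the
    # request's block IDs; if True, they replace them.
    resumed_from_preemption: bool
    new_token_ids: List[int]      # the field added by this commit
    new_block_ids: List[int]
    num_computed_tokens: int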
@ -91,9 +92,9 @@ class SchedulerOutput:
# Total number of tokens scheduled for all requests.
# Equal to sum(num_scheduled_tokens.values())
total_num_scheduled_tokens: int
# req_id -> spec_decode_tokens
# If a request does not have any spec decode tokens, it will
# not be included in the dictionary.
# req_id -> spec_token_ids
# If a request does not have any spec decode tokens, it will not be
# included in the dictionary.
scheduled_spec_decode_tokens: Dict[str, List[int]]
# req_id -> encoder input indices that need processing.
# E.g., if a request has [0, 1], it could mean the vision encoder needs


@ -2,7 +2,7 @@
import gc
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@ -184,7 +184,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.max_model_len,
self.max_num_tokens),
dtype=np.int32)
self.arange_cpu = torch.from_numpy(self.arange_np)
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
# a faster version of creating a new tensor every time. Thus, we should
# not make any assumptions about the values in these tensors.
@ -327,7 +326,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
req_state = self.requests[req_id]
# Update the cached states.
req_state.num_computed_tokens = req_data.num_computed_tokens
num_computed_tokens = req_data.num_computed_tokens
req_state.num_computed_tokens = num_computed_tokens
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec decode tokens.
num_new_tokens = (num_computed_tokens +
len(req_data.new_token_ids) -
req_state.num_tokens)
new_token_ids = (req_data.new_token_ids[-num_new_tokens:]
if num_new_tokens > 0 else [])
req_state.output_token_ids.extend(new_token_ids)
# Update the block IDs.
if not req_data.resumed_from_preemption:
# Append the new blocks to the existing block IDs.
req_state.block_ids.extend(req_data.new_block_ids)
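To see what the num_new_tokens arithmetic above does, with illustrative values:

# Illustrative values only.
req_state_num_tokens = 12        # tokens the runner already tracks
num_computed_tokens = 10         # reported by the scheduler
new_token_ids = [55, 56, 57]     # verified tokens sent by the scheduler

num_new_tokens = (num_computed_tokens + len(new_token_ids) -
                  req_state_num_tokens)
assert num_new_tokens == 1
# Only the tail of new_token_ids is actually new to the runner; earlier
# entries were already appended in a previous step.
appended = new_token_ids[-num_new_tokens:] if num_new_tokens > 0 else []
assert appended == [57]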
@ -346,12 +355,30 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
req_data.num_computed_tokens)
start_index = len(req_state.block_ids) - len(
req_data.new_block_ids)
num_computed_tokens)
start_index = (len(req_state.block_ids) -
len(req_data.new_block_ids))
self.input_batch.block_table.append_row(req_index, start_index,
req_data.new_block_ids)
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(req_data.new_token_ids)
self.input_batch.token_ids_cpu[
req_index,
start_token_index:end_token_index] = req_data.new_token_ids
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, [])
if spec_token_ids:
start_index = end_token_index
end_token_index += len(spec_token_ids)
self.input_batch.token_ids_cpu[
req_index, start_index:end_token_index] = spec_token_ids
# NOTE(woosuk): `num_tokens` here may include spec decode tokens.
self.input_batch.num_tokens[req_index] = end_token_index
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
# Add the new or resumed requests to the persistent batch.
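The token writes a few lines above can be pictured for a single request row as follows (a sketch with made-up numbers, not part of the commit):

import numpy as np

token_row = np.zeros(16, dtype=np.int32)   # one row of token_ids_cpu
num_computed_tokens = 5
new_token_ids = [11, 12]                   # verified tokens from the scheduler
spec_token_ids = [21, 22, 23]              # speculative tokens, if any

start = num_computed_tokens
end = start + len(new_token_ids)
token_row[start:end] = new_token_ids                       # positions 5:7
token_row[end:end + len(spec_token_ids)] = spec_token_ids  # positions 7:10
num_tokens = end + len(spec_token_ids)     # 10; includes the spec tokens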
@ -374,7 +401,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return batch_changed
def _prepare_inputs(
self, scheduler_output: "SchedulerOutput"
self,
scheduler_output: "SchedulerOutput",
) -> Tuple[FlashAttentionMetadata, torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
@ -387,24 +415,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens_list: List[int] = []
num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
max_num_scheduled_tokens = 0
all_spec_token_ids: List[int] = []
num_spec_tokens_list: List[int] = []
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens_list.append(num_tokens)
num_scheduled_tokens[i] = num_tokens
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, [])
all_spec_token_ids.extend(spec_token_ids)
num_spec_tokens_list.append(len(spec_token_ids))
num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
dtype=np.int32)
assert max_num_scheduled_tokens > 0
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
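The mapping described in the comment above can be reproduced with np.repeat over a request-index arange; a minimal standalone sketch (not part of the diff):

import numpy as np

num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int32)
req_indices = np.repeat(np.arange(3, dtype=np.int32), num_scheduled_tokens)
# req_indices -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]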
@ -441,78 +459,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
use_spec_decode = len(all_spec_token_ids) > 0
if use_spec_decode:
# 1. Write spec_token_ids to input batch.
# Step 1. Get req indices that perform spec decode and repeat
# the req indices by the number of spec tokens. Note
# for requests that don't perform spec decode, the
# number of spec tokens is 0 and the req index is
# repeated 0 times.
# E.g., num_spec_tokens_list: [3, 0, 2, 0, 1]
# spec_req_indices: [0, 0, 0, 2, 2, 4]
spec_req_indices = np.repeat(self.arange_np[:num_reqs],
num_spec_tokens_list)
# spec_offsets: offsets within each spec token list.
# E.g., [1, 2, 3, 1, 2, 1], TODO: avoid the for loop here
spec_offsets = np.concatenate(
[self.arange_np[1:val + 1] for val in num_spec_tokens_list])
# spec_seq_offsets: offsets within each sequence.
# E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2]
# after repeating: [1, 1, 1, 3, 3, 2]
# spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1]
# = [2, 3, 4, 4, 5, 3]
spec_seq_offsets = np.repeat(
self.input_batch.num_computed_tokens_cpu[:num_reqs],
num_spec_tokens_list) + spec_offsets
# cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3]
cumsums_spec_offsets = (
spec_seq_offsets +
spec_req_indices * self.input_batch.token_ids_cpu.shape[1])
cumsums_spec_offsets = torch.from_numpy(cumsums_spec_offsets).to(
torch.int64)
all_spec_token_ids = torch.tensor(all_spec_token_ids,
device="cpu",
dtype=self.input_ids_cpu.dtype)
# Step 2. Write spec token ids to input_ids_cpu.
self.input_batch.token_ids_cpu_tensor.flatten().scatter_(
0, cumsums_spec_offsets, all_spec_token_ids)
# 2. Get spec decode logits indices.
# E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
# cu_num_tokens: [4, 104, 107, 207, 209]
# num_spec_tokens_list: [3, 0, 2, 0, 1]
# num_sampled_tokens: [4, 1, 3, 1, 2]
# spec_decode_logits_indices:
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
num_spec_tokens_np = np.array(num_spec_tokens_list, dtype=np.int32)
num_sampled_tokens = num_spec_tokens_np + 1
# logits_start_loc: [0, 103, 104, 206, 207]
logits_start_loc = cu_num_tokens - num_sampled_tokens
# [0, 103, 104, 206, 207] ->
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
# The following three lines:
# [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
# Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
# -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
cumsums_sampled_offsets = np.repeat(
cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
# Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
# -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
total_num_sampled_tokens = num_sampled_tokens.sum()
sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
cumsums_sampled_offsets)
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
spec_decode_logits_indices = logits_start_loc + sampled_arange
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
@ -606,9 +552,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
suffix_kv_lens=suffix_kv_lens,
)
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
if use_spec_decode:
logits_indices = torch.from_numpy(spec_decode_logits_indices).to(
self.device, non_blocking=True)
logits_indices = self._calc_spec_decode_metadata(
scheduler_output, cu_num_tokens)
else:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
@ -762,6 +710,53 @@ class GPUModelRunner(LoRAModelRunnerMixin):
mrope_pos_ptr += completion_part_len
def _calc_spec_decode_metadata(
self,
scheduler_output: "SchedulerOutput",
cu_num_tokens: np.ndarray,
) -> torch.Tensor:
# Get the number of spec decode tokens for each request.
num_reqs = self.input_batch.num_reqs
num_spec_decode_tokens = np.empty(num_reqs, dtype=np.int32)
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
num_spec_decode_tokens[i] = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
# Get spec decode logits indices.
# E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
# cu_num_tokens: [4, 104, 107, 207, 209]
# num_spec_tokens_list: [3, 0, 2, 0, 1]
# num_sampled_tokens: [4, 1, 3, 1, 2]
# spec_decode_logits_indices:
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
num_sampled_tokens = num_spec_decode_tokens + 1
# logits_start_loc: [0, 103, 104, 206, 207]
logits_start_loc = cu_num_tokens - num_sampled_tokens
# [0, 103, 104, 206, 207] ->
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
# The following three lines:
# [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
# Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
# -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
cumsums_sampled_offsets = np.repeat(
cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
# Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
# -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
total_num_sampled_tokens = num_sampled_tokens.sum()
sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
cumsums_sampled_offsets)
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
spec_decode_logits_indices = logits_start_loc + sampled_arange
return torch.from_numpy(spec_decode_logits_indices).to(
self.device, non_blocking=True)
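The index arithmetic in _calc_spec_decode_metadata can be checked standalone with the example numbers from the comments; a self-contained numpy sketch:

import numpy as np

num_scheduled_tokens = np.array([4, 100, 3, 100, 2], dtype=np.int32)
num_spec_decode_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)
cu_num_tokens = np.cumsum(num_scheduled_tokens)        # [4, 104, 107, 207, 209]

num_sampled_tokens = num_spec_decode_tokens + 1        # [4, 1, 3, 1, 2]
logits_start_loc = cu_num_tokens - num_sampled_tokens  # [0, 103, 104, 206, 207]
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)  # [4, 5, 8, 9, 11]
cumsums_sampled_offsets = np.repeat(
    cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
total_num_sampled_tokens = num_sampled_tokens.sum()
sampled_arange = np.arange(total_num_sampled_tokens) - cumsums_sampled_offsets
spec_decode_logits_indices = logits_start_loc + sampled_arange
assert spec_decode_logits_indices.tolist() == [
    0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]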
def _prepare_sampling(
self,
batch_changed: bool,
@ -773,7 +768,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for req_id, req in self.requests.items()}
sampling_metadata = self.input_batch.make_sampling_metadata(
req_id_output_token_ids, req_to_spec_token_ids, not batch_changed)
req_id_output_token_ids,
req_to_spec_token_ids,
skip_copy=not batch_changed)
return sampling_metadata
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
@ -960,28 +957,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
num_reqs = self.input_batch.num_reqs
request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
req_ids: List[str] = []
# Because `input_batch.req_ids` is a list of length `max_num_reqs`,
# we need to stop at `num_reqs`.
# FIXME(woosuk): This is hacky. Refactor.
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
req_ids.append(req_id)
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len >= req_state.num_tokens:
request_seq_lens.append((i, req_state, seq_len))
else:
# Ignore the sampled token from the partial request.
if seq_len < req_state.num_tokens:
# Ignore the sampled token.
# Rewind the generator state as if the token was not sampled.
generator = self.input_batch.generators.get(i)
if generator is not None:
# This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4)
# num_reqs entries should be non-None
assert all(
req_id is not None for req_id in
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_tensors = sampler_output.logprobs_tensors
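For intuition, the partial-request check above with illustrative numbers: a chunked prefill that has not yet reached the end of the prompt discards its sampled token and rewinds the RNG state.

# Illustrative numbers for a chunked-prefill request.
num_computed_tokens = 48      # req_state.num_computed_tokens
num_scheduled = 16            # scheduler_output.num_scheduled_tokens[req_id]
num_request_tokens = 128      # req_state.num_tokens (prompt + outputs)

seq_len = num_computed_tokens + num_scheduled     # 64
is_partial = seq_len < num_request_tokens         # True: mid-prefill, so the
                                                  # sampled token is ignored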
@ -994,29 +987,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output,
)
# Update batch with the valid generated tokens.
# Get the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
for i, req_state, seq_len in request_seq_lens:
token_id = valid_sampled_token_ids[i][0]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids.append(token_id)
self.input_batch.num_tokens[i] += 1
else:
# Includes spec decode tokens.
valid_mask = sampled_token_ids != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist()
# TODO(woosuk): Optimize this.
valid_sampled_token_ids = [
seq.tolist()
for seq in sampled_token_ids[valid_mask].split(gen_lens)
]
self.input_batch.num_tokens[:num_reqs] += gen_lens
for i, req_state, seq_len in request_seq_lens:
target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
self.input_batch.token_ids_cpu[
i, target_slice] = valid_sampled_token_ids[i]
req_state.output_token_ids.extend(valid_sampled_token_ids[i])
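The mask-and-split idiom above, shown on a tiny tensor (INVALID stands in for INVALID_TOKEN_ID; values are illustrative):

import torch

INVALID = -1
sampled_token_ids = torch.tensor([[5, 6, INVALID],
                                  [7, INVALID, INVALID]])
valid_mask = sampled_token_ids != INVALID
gen_lens = valid_mask.sum(dim=1).tolist()                       # [2, 1]
valid_sampled_token_ids = [
    seq.tolist() for seq in sampled_token_ids[valid_mask].split(gen_lens)
]
# valid_sampled_token_ids -> [[5, 6], [7]]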
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,