chunked prefilling

Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>

Author: Woosuk Kwon
Date: 2025-09-15 19:41:17 +00:00
parent 67852c1036
commit 69b17891a3

2 changed files with 8 additions and 9 deletions


@@ -54,6 +54,8 @@ class InputBatch:
     num_scheduled_tokens: np.ndarray
     # sum(num_scheduled_tokens)
     num_tokens: int
+    # [num_reqs]
+    is_chunked_prefilling: np.ndarray
     # [max_num_batched_tokens]
     input_ids: torch.Tensor
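
The new field is a per-request boolean mask: element i is True while request i's prompt is still being prefilled chunk by chunk, meaning no token should be sampled for it this step. A minimal stand-in for the fields touched here, with hypothetical values (the real class carries many more fields):

    import numpy as np
    from dataclasses import dataclass

    # Minimal stand-in for the InputBatch fields touched in this diff;
    # the real class has many more fields.
    @dataclass
    class MiniInputBatch:
        num_scheduled_tokens: np.ndarray   # [num_reqs]
        num_tokens: int                    # sum(num_scheduled_tokens)
        is_chunked_prefilling: np.ndarray  # [num_reqs], bool

    batch = MiniInputBatch(
        num_scheduled_tokens=np.array([64, 36, 1]),
        num_tokens=101,
        is_chunked_prefilling=np.array([True, False, False]),
    )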


@@ -203,6 +203,9 @@ class GPUModelRunner:
         max_seq_len = int(seq_lens_np.max())
         seq_lens = seq_lens.gpu[:num_reqs]
         num_tokens = self.req_states.num_tokens[idx_mapping_np]
+        is_chunked_prefilling = seq_lens_np < num_tokens
         # Slot mappings: [num_kv_cache_groups, num_tokens]
         slot_mappings = self.block_tables.compute_slot_mappings(
             query_start_loc, positions.gpu[:num_tokens])
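
The mask is now computed once during batch preparation. Assuming `seq_lens_np` holds, per request, the tokens covered after this step's chunk (computed tokens plus newly scheduled ones, per the code removed further down), a request is mid-prefill exactly when that count still falls short of its total token count. A worked example with hypothetical numbers:

    import numpy as np

    # Per-request totals, mirroring self.req_states.num_tokens[idx_mapping_np]:
    # all tokens each request currently holds (prompt plus generated).
    num_tokens = np.array([100, 100, 50])
    # Tokens covered once this step's chunk is processed.
    seq_lens_np = np.array([64, 100, 50])

    # True where the chunk still falls short of the request's prompt,
    # i.e. the prefill is not yet finished.
    is_chunked_prefilling = seq_lens_np < num_tokens
    print(is_chunked_prefilling)  # [ True False False]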
@@ -222,6 +225,7 @@ class GPUModelRunner:
             idx_mapping_np=idx_mapping_np,
             num_scheduled_tokens=num_scheduled_tokens,
             num_tokens=num_tokens,
+            is_chunked_prefilling=is_chunked_prefilling,
             input_ids=input_ids,
             positions=positions,
             attn_metadata=attn_metadata,
@@ -248,18 +252,11 @@ class GPUModelRunner:
         sampler_output: SamplerOutput,
     ) -> np.ndarray:
         # Get the number of sampled tokens.
-        # Handle requests that are chunked-prefilling.
-        idx_mapping_np = input_batch.idx_mapping_np
-        num_computed_tokens = self.req_states.num_computed_tokens[
-            idx_mapping_np]
-        post_num_computed_tokens = (num_computed_tokens +
-                                    input_batch.num_scheduled_tokens)
-        num_tokens = self.req_states.num_tokens[idx_mapping_np]
-        is_chunked_prefilling = post_num_computed_tokens < num_tokens
         # 0 if chunked-prefilling, 1 if not.
+        is_chunked_prefilling = input_batch.is_chunked_prefilling
         num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
         # Increment the number of tokens.
+        idx_mapping_np = input_batch.idx_mapping_np
         self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
         return num_sampled_tokens
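
The payoff is in this last hunk: instead of re-deriving the mask from `num_computed_tokens` (the check removed above is equivalent, since the post-step computed count is exactly the sequence length used during batch preparation), the runner reads the precomputed mask and samples one token per finished request and zero per in-flight prefill. A small sketch of that arithmetic:

    import numpy as np

    is_chunked_prefilling = np.array([True, False, False])

    # 0 sampled tokens while chunked-prefilling, 1 otherwise.
    num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
    print(num_sampled_tokens)  # [0 1 1]

    # Each request's running token count grows by what it sampled.
    num_tokens = np.array([100, 100, 50])
    num_tokens += num_sampled_tokens
    print(num_tokens)  # [100 101  51]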