mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-14 09:57:02 +08:00
chunked prefilling
Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
This commit is contained in:
parent
67852c1036
commit
69b17891a3
@ -54,6 +54,8 @@ class InputBatch:
|
||||
num_scheduled_tokens: np.ndarray
|
||||
# sum(num_scheduled_tokens)
|
||||
num_tokens: int
|
||||
# [num_reqs]
|
||||
is_chunked_prefilling: np.ndarray
|
||||
|
||||
# [max_num_batched_tokens]
|
||||
input_ids: torch.Tensor
|
||||
|
||||
@ -203,6 +203,9 @@ class GPUModelRunner:
|
||||
max_seq_len = int(seq_lens_np.max())
|
||||
seq_lens = seq_lens.gpu[:num_reqs]
|
||||
|
||||
num_tokens = self.req_states.num_tokens[idx_mapping_np]
|
||||
is_chunked_prefilling = seq_lens_np < num_tokens
|
||||
|
||||
# Slot mappings: [num_kv_cache_groups, num_tokens]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
query_start_loc, positions.gpu[:num_tokens])
|
||||
@ -222,6 +225,7 @@ class GPUModelRunner:
|
||||
idx_mapping_np=idx_mapping_np,
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
num_tokens=num_tokens,
|
||||
is_chunked_prefilling=is_chunked_prefilling,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
attn_metadata=attn_metadata,
|
||||
@ -248,18 +252,11 @@ class GPUModelRunner:
|
||||
sampler_output: SamplerOutput,
|
||||
) -> np.ndarray:
|
||||
# Get the number of sampled tokens.
|
||||
# Handle requests that are chunked-prefilling.
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
num_computed_tokens = self.req_states.num_computed_tokens[
|
||||
idx_mapping_np]
|
||||
post_num_computed_tokens = (num_computed_tokens +
|
||||
input_batch.num_scheduled_tokens)
|
||||
num_tokens = self.req_states.num_tokens[idx_mapping_np]
|
||||
|
||||
is_chunked_prefilling = post_num_computed_tokens < num_tokens
|
||||
# 0 if chunked-prefilling, 1 if not.
|
||||
is_chunked_prefilling = input_batch.is_chunked_prefilling
|
||||
num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32)
|
||||
# Increment the number of tokens.
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens
|
||||
return num_sampled_tokens
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user