diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index 63fdb02c1155b..65bafb338923f 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -54,6 +54,8 @@ class InputBatch: num_scheduled_tokens: np.ndarray # sum(num_scheduled_tokens) num_tokens: int + # [num_reqs] + is_chunked_prefilling: np.ndarray # [max_num_batched_tokens] input_ids: torch.Tensor diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 7c29a4b392398..79fdd91ed6b5d 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -203,6 +203,9 @@ class GPUModelRunner: max_seq_len = int(seq_lens_np.max()) seq_lens = seq_lens.gpu[:num_reqs] + num_tokens = self.req_states.num_tokens[idx_mapping_np] + is_chunked_prefilling = seq_lens_np < num_tokens + # Slot mappings: [num_kv_cache_groups, num_tokens] slot_mappings = self.block_tables.compute_slot_mappings( query_start_loc, positions.gpu[:num_tokens]) @@ -222,6 +225,7 @@ class GPUModelRunner: idx_mapping_np=idx_mapping_np, num_scheduled_tokens=num_scheduled_tokens, num_tokens=num_tokens, + is_chunked_prefilling=is_chunked_prefilling, input_ids=input_ids, positions=positions, attn_metadata=attn_metadata, @@ -248,18 +252,11 @@ class GPUModelRunner: sampler_output: SamplerOutput, ) -> np.ndarray: # Get the number of sampled tokens. - # Handle requests that are chunked-prefilling. - idx_mapping_np = input_batch.idx_mapping_np - num_computed_tokens = self.req_states.num_computed_tokens[ - idx_mapping_np] - post_num_computed_tokens = (num_computed_tokens + - input_batch.num_scheduled_tokens) - num_tokens = self.req_states.num_tokens[idx_mapping_np] - - is_chunked_prefilling = post_num_computed_tokens < num_tokens # 0 if chunked-prefilling, 1 if not. + is_chunked_prefilling = input_batch.is_chunked_prefilling num_sampled_tokens = (~is_chunked_prefilling).astype(np.int32) # Increment the number of tokens. + idx_mapping_np = input_batch.idx_mapping_np self.req_states.num_tokens[idx_mapping_np] += num_sampled_tokens return num_sampled_tokens