From e107680d8a1621ec78871b6649e193c178ad2139 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 15 Sep 2025 21:19:18 +0000
Subject: [PATCH] wip

Signed-off-by: Woosuk Kwon
---
 vllm/v1/worker/gpu/model_runner.py | 79 ++++++++++++++++++++++++------
 vllm/v1/worker/gpu/states.py       |  8 +++
 2 files changed, 73 insertions(+), 14 deletions(-)

diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 334c6f953ff2f..e570e4924deeb 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -10,6 +10,7 @@ from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_pin_memory_available
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.sampler import SamplerOutput
@@ -178,7 +179,6 @@ class GPUModelRunner:
         positions = self.input_buffers.positions
         query_start_loc = self.input_buffers.query_start_loc
         seq_lens = self.input_buffers.seq_lens
-
         prepare_inputs(
             idx_mapping_np,
             self.req_states.token_ids,
@@ -196,30 +196,59 @@ class GPUModelRunner:
         # tensors from CPU to GPU, because they may include paddings needed
         # for full CUDA graph mode.
         query_start_loc.copy_to_gpu()
-        query_start_loc = query_start_loc.gpu[:num_reqs + 1]
+        query_start_loc_cpu = query_start_loc.cpu[:num_reqs + 1]
+        query_start_loc_gpu = query_start_loc.gpu[:num_reqs + 1]
         max_query_len = int(num_scheduled_tokens.max())
 
         seq_lens.copy_to_gpu()
+        seq_lens_cpu = seq_lens.cpu[:num_reqs]
         seq_lens_np = seq_lens.np[:num_reqs]
         max_seq_len = int(seq_lens_np.max())
-        seq_lens = seq_lens.gpu[:num_reqs]
+        seq_lens_gpu = seq_lens.gpu[:num_reqs]
 
+        num_computed_tokens_np = self.req_states.num_computed_tokens[
+            idx_mapping_np]
+        num_computed_tokens_cpu = torch.from_numpy(num_computed_tokens_np)
         num_tokens = self.req_states.num_tokens[idx_mapping_np]
         is_chunked_prefilling = seq_lens_np < num_tokens
 
         # Slot mappings: [num_kv_cache_groups, num_tokens]
         slot_mappings = self.block_tables.compute_slot_mappings(
-            query_start_loc, positions.gpu[:num_tokens])
-
-        logits_indices = query_start_loc[1:] - 1
+            query_start_loc_gpu, positions.gpu[:num_tokens])
+        logits_indices = query_start_loc_gpu[1:] - 1
 
+        # Layer name -> attention metadata.
         attn_metadata: dict[str, Any] = {}
         for i, kv_cache_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
             block_table = block_tables[i]
             slot_mapping = slot_mappings[i]
-            num_common_prefix_blocks = 0
+            common_attn_metadata = CommonAttentionMetadata(
+                query_start_loc=query_start_loc_gpu,
+                query_start_loc_cpu=query_start_loc_cpu,
+                seq_lens=seq_lens_gpu,
+                seq_lens_cpu=seq_lens_cpu,
+                num_computed_tokens_cpu=num_computed_tokens_cpu,
+                num_reqs=num_reqs,
+                num_actual_tokens=num_tokens,
+                max_query_len=max_query_len,
+                max_seq_len=max_seq_len,
+                block_table_tensor=block_table,
+                slot_mapping=slot_mapping,
+                logits_indices_padded=None,
+                num_logits_indices=logits_indices.size(0),
+                causal=True,
+                encoder_seq_lens=None,
+            )
+
+            attn_metadata_builder = self.attn_metadata_builders[i]
+            attn_metadata_i = attn_metadata_builder.build(
+                common_prefix_len=0,
+                common_attn_metadata=common_attn_metadata,
+            )
+            for layer_name in kv_cache_spec.layer_names:
+                attn_metadata[layer_name] = attn_metadata_i
 
         return InputBatch(
             req_ids=req_ids,
@@ -237,8 +266,8 @@ class GPUModelRunner:
 
     def sample(
         self,
-        input_batch: InputBatch,
         logits: torch.Tensor,
+        input_batch: InputBatch,
     ) -> SamplerOutput:
         sampling_metadata = self.req_states.make_sampling_metadata(
             input_batch.idx_mapping_np)
@@ -246,13 +275,31 @@ class GPUModelRunner:
             logits=logits,
             sampling_metadata=sampling_metadata,
         )
-
         return sampler_output
 
+    def compute_prompt_logprobs(
+        self,
+        hidden_states: torch.Tensor,
+        input_batch: InputBatch,
+    ):
+        idx_mapping_np = input_batch.idx_mapping_np
+        needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[
+            idx_mapping_np]
+        if not np.any(needs_prompt_logprobs):
+            # Common case.
+            # No request in the batch needs prompt logprobs.
+            return None
+
+        num_prompt_tokens_scheduled = ...
+        if not np.any((num_prompt_tokens_scheduled > 0) & needs_prompt_logprobs):
+            # All requests that need prompt logprobs have already computed them.
+            return None
+        return
+
     def postprocess(
         self,
-        input_batch: InputBatch,
         sampler_output: SamplerOutput,
+        input_batch: InputBatch,
     ) -> np.ndarray:
         # Get the number of sampled tokens.
         # 0 if chunked-prefilling, 1 if not.
@@ -282,9 +329,13 @@ class GPUModelRunner:
             positions=input_batch.positions,
         )
 
-        sampling_hidden_states = hidden_states[input_batch.logits_indices]
-        logits = self.model.compute_logits(sampling_hidden_states, None)
-        sampler_output = self.sample(input_batch, logits)
+        # Compute logits to sample next tokens.
+        sample_hidden_states = hidden_states[input_batch.logits_indices]
+        logits = self.model.compute_logits(sample_hidden_states, None)
 
-        num_sampled_tokens = self.postprocess(input_batch, sampler_output)
+        sampler_output = self.sample(logits, input_batch)
+        prompt_logprobs = self.compute_prompt_logprobs(hidden_states,
+                                                       input_batch)
+
+        output = self.postprocess(sampler_output, input_batch)
         return output
diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py
index 63286dcf742c5..3c756bfef5668 100644
--- a/vllm/v1/worker/gpu/states.py
+++ b/vllm/v1/worker/gpu/states.py
@@ -51,6 +51,7 @@ class RequestState:
         )
         self.num_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
         self.num_computed_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
+        self.num_prompt_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
 
         # Last sampled token ids.
         self.last_token = torch.zeros(
@@ -67,6 +68,8 @@ class RequestState:
         # -1 means no logprobs are requested.
         self.num_logprobs.fill(-1)
 
+        self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
+
     @property
     def num_reqs(self) -> int:
         return len(self.req_id_to_index)
@@ -85,6 +88,7 @@ class RequestState:
 
         prompt_len = len(prompt_token_ids)
         self.num_tokens[req_idx] = prompt_len
+        self.num_prompt_tokens[req_idx] = prompt_len
         self.token_ids[req_idx, :prompt_len] = prompt_token_ids
         self.num_computed_tokens[req_idx] = num_computed_tokens
 
@@ -102,6 +106,10 @@ class RequestState:
             num_logprobs = -1
         self.num_logprobs[req_idx] = num_logprobs
 
+        # For now, only support prompt logprobs for the prompt tokens.
+        needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
+        self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
+
     def append_token_ids(
         self,
         req_idx: int,
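
For reference, below is a minimal standalone numpy sketch of the gating logic that compute_prompt_logprobs() is building toward. The toy values and the num_prompt_tokens_scheduled array are assumptions for illustration only (the patch still leaves that value as a placeholder); only the mask combination mirrors the diff. Note the parentheses around the comparison are required because `&` binds tighter than `>` in Python.

    import numpy as np

    # Toy per-request state for a batch of three requests (illustrative values).
    needs_prompt_logprobs = np.array([True, False, True])
    # Hypothetical: number of prompt tokens scheduled for each request this step.
    num_prompt_tokens_scheduled = np.array([0, 4, 2])

    # A request needs prompt-logprob work this step only if it both asked for
    # prompt logprobs and still has prompt tokens being processed.
    active = (num_prompt_tokens_scheduled > 0) & needs_prompt_logprobs

    if not np.any(active):
        # Common case: skip the extra logprob computation entirely.
        print("no prompt logprobs needed this step")
    else:
        print("compute prompt logprobs for request indices:", np.nonzero(active)[0])
        # -> compute prompt logprobs for request indices: [2]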