diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4a6856bf4fef..08e13ab887bf 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -783,28 +783,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         logits_indices_padded = None
         if self.cache_config.kv_sharing_fast_prefill:
-            assert self.kv_sharing_fast_prefill_logits_indices is not None
-            num_logits = logits_indices.shape[0]
-            assert num_logits > 0
-            self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
+            logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
                 logits_indices)
-            # There might have leftover indices in logits_indices[num_logits:]
-            # from previous iterations, whose values may be greater than the
-            # batch size in the current iteration. To ensure indices are always
-            # valid, we fill the padded indices with the last index.
-            self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
-                logits_indices[-1].item())
-            if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-                    and num_logits <= self.cudagraph_batch_sizes[-1]):
-                # Use piecewise CUDA graphs.
-                # Add padding to the batch size.
-                num_logits_padded = self.vllm_config.pad_for_cudagraph(
-                    num_logits)
-            else:
-                num_logits_padded = num_logits
-            logits_indices_padded = (
-                self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
-            )
 
         attn_metadata: dict[str, Any] = {}
 
@@ -1109,6 +1089,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )
         return metadata
 
+    def _prepare_kv_sharing_fast_prefill(
+        self,
+        logits_indices: torch.Tensor,
+    ) -> torch.Tensor:
+        assert self.kv_sharing_fast_prefill_logits_indices is not None
+        num_logits = logits_indices.shape[0]
+        assert num_logits > 0
+        self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
+            logits_indices)
+        # There might have leftover indices in logits_indices[num_logits:]
+        # from previous iterations, whose values may be greater than the
+        # batch size in the current iteration. To ensure indices are always
+        # valid, we fill the padded indices with the last index.
+        self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
+            logits_indices[-1].item())
+        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                and num_logits <= self.cudagraph_batch_sizes[-1]):
+            # Use piecewise CUDA graphs.
+            # Add padding to the batch size.
+            num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits)
+        else:
+            num_logits_padded = num_logits
+        logits_indices_padded = (
+            self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded])
+        return logits_indices_padded
+
     def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
     if not scheduled_encoder_inputs:
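
Note: the sketch below is only an illustration of the index-padding idea that the extracted helper implements (copy the valid logits indices into a persistent buffer, overwrite stale tail entries with the last valid index, then pad the length to a CUDA-graph-friendly size). The buffer size, CUDAGRAPH_BATCH_SIZES, pad_for_cudagraph, and prepare_padded_indices names are hypothetical stand-ins for this example, not vLLM APIs.

import torch

MAX_NUM_LOGITS = 64  # assumed capacity of the preallocated buffer
CUDAGRAPH_BATCH_SIZES = [8, 16, 32, 64]  # assumed captured sizes, ascending

# Preallocated once and reused across iterations, so it may hold stale values.
logits_indices_buffer = torch.zeros(MAX_NUM_LOGITS, dtype=torch.int64)


def pad_for_cudagraph(n: int) -> int:
    """Round n up to the smallest captured batch size that fits."""
    for size in CUDAGRAPH_BATCH_SIZES:
        if n <= size:
            return size
    return n


def prepare_padded_indices(logits_indices: torch.Tensor) -> torch.Tensor:
    num_logits = logits_indices.shape[0]
    assert num_logits > 0
    # Copy the valid indices into the persistent buffer.
    logits_indices_buffer[:num_logits].copy_(logits_indices)
    # Stale entries from earlier iterations may exceed the current batch
    # size, so fill the tail with the last valid index to keep every slot
    # a legal index.
    logits_indices_buffer[num_logits:].fill_(logits_indices[-1].item())
    if num_logits <= CUDAGRAPH_BATCH_SIZES[-1]:
        num_padded = pad_for_cudagraph(num_logits)
    else:
        num_padded = num_logits
    return logits_indices_buffer[:num_padded]


if __name__ == "__main__":
    out = prepare_padded_indices(torch.tensor([3, 7, 11]))
    print(out)  # length 8: [3, 7, 11, 11, 11, 11, 11, 11]

Padding with the last valid index (rather than zeros) keeps the extra gather positions pointing at a real token, so the padded run is safe to execute under a captured CUDA graph and the surplus logits are simply discarded afterwards.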