[Misc] Move fast prefill logic to separate method (#24013)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-08-31 22:40:38 -07:00 committed by GitHub
parent acc1a6e10a
commit b55713683c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -783,28 +783,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits_indices_padded = None
if self.cache_config.kv_sharing_fast_prefill:
assert self.kv_sharing_fast_prefill_logits_indices is not None
num_logits = logits_indices.shape[0]
assert num_logits > 0
self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
logits_indices)
# There might be leftover indices in logits_indices[num_logits:]
# from previous iterations, whose values may be greater than the
# batch size in the current iteration. To ensure indices are always
# valid, we fill the padded indices with the last index.
self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
logits_indices[-1].item())
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and num_logits <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_logits_padded = self.vllm_config.pad_for_cudagraph(
num_logits)
else:
num_logits_padded = num_logits
logits_indices_padded = (
self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
)
attn_metadata: dict[str, Any] = {}
@ -1109,6 +1089,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
return metadata
def _prepare_kv_sharing_fast_prefill(
    self,
    logits_indices: torch.Tensor,
) -> torch.Tensor:
    """Stage logits indices in the persistent fast-prefill buffer.

    Copies ``logits_indices`` into
    ``self.kv_sharing_fast_prefill_logits_indices`` and returns a view of
    that buffer, padded to a CUDA-graph-captured batch size when piecewise
    CUDA graphs are enabled.

    Args:
        logits_indices: Non-empty 1-D tensor of indices whose logits are
            needed this step.

    Returns:
        A slice of the persistent buffer of length ``num_logits`` (or the
        CUDA-graph-padded length), safe to index with in this iteration.
    """
    buffer = self.kv_sharing_fast_prefill_logits_indices
    assert buffer is not None
    num_logits = logits_indices.shape[0]
    assert num_logits > 0
    buffer[:num_logits].copy_(logits_indices)
    # Entries past num_logits may be stale values from earlier iterations
    # and could exceed the current batch size. Overwrite them with the
    # last valid index so every slot in the buffer stays in range.
    buffer[num_logits:].fill_(logits_indices[-1].item())
    use_piecewise_graphs = (
        self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
        and num_logits <= self.cudagraph_batch_sizes[-1])
    if use_piecewise_graphs:
        # Piecewise CUDA graphs: round the batch size up to a captured one.
        padded_len = self.vllm_config.pad_for_cudagraph(num_logits)
    else:
        padded_len = num_logits
    return buffer[:padded_len]
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs: