[V1][PP] Cache Intermediate Tensors (#13353)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-02-16 10:02:27 -08:00 committed by GitHub
parent 7b89386553
commit e18227b04a


@@ -2,7 +2,7 @@
import gc
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
import numpy as np
import torch
@@ -149,6 +149,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
# self.intermediate_tensors # Set after load_model
# Only relevant for models using M-RoPE (e.g., Qwen2-VL)
if self.uses_mrope:
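
The new comment above flags that self.intermediate_tensors is created only after load_model, since its shapes and dtype depend on the loaded model. A minimal sketch of that lazy, allocate-once pattern, using assumed names and toy sizes rather than vLLM's real classes:

import torch


class RunnerSketch:
    """Illustrative stand-in for GPUModelRunner (not the real class)."""

    def __init__(self, max_num_tokens: int, hidden_size: int):
        self.max_num_tokens = max_num_tokens
        self.hidden_size = hidden_size
        # NOTE: self.intermediate_tensors is deliberately not created here;
        # it is allocated once, lazily, after the model has been loaded.

    def _ensure_intermediate_tensors(self) -> dict:
        if not hasattr(self, "intermediate_tensors"):
            # One worst-case-sized allocation, reused (and sliced) on every
            # subsequent step.
            self.intermediate_tensors = {
                "hidden_states": torch.zeros(self.max_num_tokens,
                                             self.hidden_size),
                "residual": torch.zeros(self.max_num_tokens,
                                        self.hidden_size),
            }
        return self.intermediate_tensors
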
@@ -869,7 +870,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> ModelRunnerOutput:
) -> Union[ModelRunnerOutput, torch.Tensor]:
batch_changed = self._update_states(scheduler_output)
if self.is_multimodal_model:
@@ -919,6 +920,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
positions = self.positions[:num_input_tokens]
if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
intermediate_tensors = IntermediateTensors({
k: v[:num_input_tokens]
for k, v in self.intermediate_tensors.items()
})
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata, self.vllm_config):
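
On non-first pipeline-parallel ranks, the decoder's inputs are now views into the cached self.intermediate_tensors buffer, sliced to the current token count, instead of freshly allocated tensors; as the surrounding comment notes, persistent buffers are what CUDA graph capture requires. A self-contained sketch of the slicing, with a toy stand-in for IntermediateTensors and made-up sizes:

import torch


class IntermediateTensorsSketch:
    """Toy stand-in: a named bag of per-token tensors passed between stages."""

    def __init__(self, tensors: dict):
        self.tensors = tensors

    def items(self):
        return self.tensors.items()


max_num_tokens, hidden_size = 512, 64  # toy sizes for illustration
# Persistent buffer, allocated once (the cached self.intermediate_tensors).
cached = IntermediateTensorsSketch({
    "hidden_states": torch.zeros(max_num_tokens, hidden_size),
    "residual": torch.zeros(max_num_tokens, hidden_size),
})

num_input_tokens = 100
# Per step: no new allocation, only views into the persistent buffer.
step_inputs = IntermediateTensorsSketch(
    {k: v[:num_input_tokens] for k, v in cached.items()})
assert (step_inputs.tensors["hidden_states"].data_ptr()
        == cached.tensors["hidden_states"].data_ptr())
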
@@ -931,7 +940,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds=inputs_embeds,
)
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
return hidden_states
hidden_states = hidden_states[:num_scheduled_tokens]
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
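
Stages other than the last have no logits to compute, so execute_model now returns the raw hidden-states tensor for them and a ModelRunnerOutput only on the final rank, matching the widened Union return annotation above. A hedged sketch with a plain boolean in place of get_pp_group().is_last_rank and a dict in place of ModelRunnerOutput:

from typing import Union

import torch


def execute_model_sketch(hidden_states: torch.Tensor,
                         is_last_rank: bool) -> Union[dict, torch.Tensor]:
    if not is_last_rank:
        # Mid-pipeline stages hand their hidden states to the next stage;
        # no logits, no sampling.
        return hidden_states
    # Last stage only: turn hidden states into a (toy) sampler output.
    return {"sampled_token_ids": hidden_states.argmax(dim=-1).tolist()}
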
@@ -1118,12 +1129,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions = self.mrope_positions[:, :num_tokens]
else:
positions = self.positions[:num_tokens]
intermediate_tensors = None
if not get_pp_group().is_first_rank:
intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=num_tokens,
dtype=self.model_config.dtype,
device=self.device)
if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
if not hasattr(self, "intermediate_tensors"):
self.intermediate_tensors = (
self.model.make_empty_intermediate_tensors(
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({
k: v[:num_tokens]
for k, v in self.intermediate_tensors.items()
})
with set_forward_context(None, self.vllm_config):
hidden_states = model(
input_ids=input_ids,
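
This last hunk applies the same caching to the dummy/profile run: instead of asking the model for fresh empty intermediate tensors sized to each call's num_tokens, the runner allocates them once at max_num_tokens, stores them on self, and slices per call. A sketch of the before/after logic under assumed names (the real hook is self.model.make_empty_intermediate_tensors):

import torch


def make_empty_tensors_sketch(batch_size: int, hidden_size: int = 64) -> dict:
    # Stand-in for the model's make_empty_intermediate_tensors hook:
    # fresh zero tensors of the requested batch size.
    return {
        "hidden_states": torch.zeros(batch_size, hidden_size),
        "residual": torch.zeros(batch_size, hidden_size),
    }


class DummyRunSketch:
    """Illustrative only; mirrors the caching logic in the hunk above."""

    def __init__(self, max_num_tokens: int):
        self.max_num_tokens = max_num_tokens

    def get_intermediate_tensors(self, num_tokens: int, is_first_rank: bool):
        if is_first_rank:
            # The first stage feeds on input_ids/embeddings, not on
            # intermediate tensors from a previous stage.
            return None
        if not hasattr(self, "intermediate_tensors"):
            # Previously: allocated per call with batch_size=num_tokens.
            # Now: allocated once at max_num_tokens and cached.
            self.intermediate_tensors = make_empty_tensors_sketch(
                self.max_num_tokens)
        # Every call returns views of the same persistent buffer.
        return {k: v[:num_tokens]
                for k, v in self.intermediate_tensors.items()}


runner = DummyRunSketch(max_num_tokens=512)
a = runner.get_intermediate_tensors(64, is_first_rank=False)
b = runner.get_intermediate_tensors(128, is_first_rank=False)
# Both calls share the same underlying storage: no per-call allocation.
assert a["hidden_states"].data_ptr() == b["hidden_states"].data_ptr()
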