[V1][PP] Cache Intermediate Tensors (#13353)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit e18227b04a
parent 7b89386553
@@ -2,7 +2,7 @@
 import gc
 import time
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
 
 import numpy as np
 import torch
@@ -149,6 +149,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
+        # self.intermediate_tensors  # Set after load_model
 
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
         if self.uses_mrope:
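For context, vLLM's IntermediateTensors is essentially a dict-like container of named activations (e.g. hidden_states and residual) passed between pipeline stages, and the comment above reserves self.intermediate_tensors for a cached, full-size instance created after load_model. The sketch below is a simplified stand-in, not vLLM's actual class, meant only to show why slicing every entry to the current token count yields zero-copy views of the cached buffers:

    # Simplified stand-in for vLLM's IntermediateTensors (illustration only).
    from typing import Dict

    import torch


    class SimpleIntermediateTensors:
        """Dict-like container of per-stage activations (hidden_states, residual)."""

        def __init__(self, tensors: Dict[str, torch.Tensor]):
            self.tensors = tensors

        def items(self):
            return self.tensors.items()

        def __getitem__(self, key: str) -> torch.Tensor:
            return self.tensors[key]


    # A persistent cache sized for the maximum batch, allocated once.
    max_num_tokens, hidden_size = 1024, 64
    cached = SimpleIntermediateTensors({
        "hidden_states": torch.zeros(max_num_tokens, hidden_size),
        "residual": torch.zeros(max_num_tokens, hidden_size),
    })

    # Per step, slice every entry down to the actual token count; the slices
    # are views into the cached buffers, so no new memory is allocated.
    num_input_tokens = 512
    step_tensors = SimpleIntermediateTensors(
        {k: v[:num_input_tokens] for k, v in cached.items()})
    assert step_tensors["hidden_states"].data_ptr() == cached["hidden_states"].data_ptr()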
@@ -869,7 +870,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> ModelRunnerOutput:
+    ) -> Union[ModelRunnerOutput, torch.Tensor]:
         batch_changed = self._update_states(scheduler_output)
 
         if self.is_multimodal_model:
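The widened return annotation reflects that on non-last pipeline ranks execute_model now hands back the raw hidden-state tensor rather than a ModelRunnerOutput. A minimal, self-contained sketch of how a caller can branch on such a Union return; every name here is a stand-in, not vLLM's worker API:

    # Hedged sketch of branching on the widened return type (illustration only).
    from typing import List, Union

    import torch


    class FakeModelRunnerOutput:
        """Stand-in for vLLM's ModelRunnerOutput (sampled token ids, etc.)."""

        def __init__(self, sampled_token_ids: List[List[int]]):
            self.sampled_token_ids = sampled_token_ids


    def fake_execute_model(is_last_rank: bool) -> Union[FakeModelRunnerOutput, torch.Tensor]:
        hidden_states = torch.randn(4, 8)
        if not is_last_rank:
            # Mid-pipeline stage: hand the raw activations back to the caller,
            # which is responsible for sending them to the next rank.
            return hidden_states
        return FakeModelRunnerOutput(sampled_token_ids=[[1], [2], [3], [4]])


    out = fake_execute_model(is_last_rank=False)
    if isinstance(out, torch.Tensor):
        print("mid-pipeline rank: forward hidden states downstream", out.shape)
    else:
        print("last rank: finished output", out.sampled_token_ids)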
@@ -919,6 +920,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             positions = self.positions[:num_input_tokens]
+
+        if get_pp_group().is_first_rank:
+            intermediate_tensors = None
+        else:
+            intermediate_tensors = IntermediateTensors({
+                k: v[:num_input_tokens]
+                for k, v in self.intermediate_tensors.items()
+            })
 
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
         with set_forward_context(attn_metadata, self.vllm_config):
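The slicing above works because indexing a preallocated tensor with [:n] returns a view into the same storage; the "Use persistent buffers for CUDA graphs" comment relies on exactly this, since graph replay needs stable addresses across steps. A plain-PyTorch demonstration, independent of vLLM:

    # Slices of a persistent buffer are views, so the underlying address stays
    # fixed across steps with different token counts (what graph replay needs).
    import torch

    max_num_tokens, hidden_size = 1024, 64
    persistent = torch.zeros(max_num_tokens, hidden_size)

    # Step 1: copy this step's received activations into the front of the
    # buffer, then operate on a view of exactly that many tokens.
    recv_step1 = torch.randn(100, hidden_size)
    persistent[:100].copy_(recv_step1)
    view1 = persistent[:100]

    # Step 2: a different token count reuses the same storage.
    recv_step2 = torch.randn(37, hidden_size)
    persistent[:37].copy_(recv_step2)
    view2 = persistent[:37]

    # Both views share the persistent buffer's memory; no per-step allocation.
    assert view1.data_ptr() == persistent.data_ptr() == view2.data_ptr()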
@@ -931,7 +940,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             inputs_embeds=inputs_embeds,
         )
         if not get_pp_group().is_last_rank:
+            # For mid-pipeline stages, return the hidden states.
             return hidden_states
+
         hidden_states = hidden_states[:num_scheduled_tokens]
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
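Only the last pipeline rank goes on to compute logits and sample; earlier ranks stop at the hidden states. A toy single-process sketch of that flow, with plain torch modules standing in for the real pipeline stages and ranks:

    # Toy sketch of the pipeline flow implied by the early return; real vLLM
    # sends/receives tensors across ranks, this merely chains stages in-process.
    import torch

    hidden_size, vocab_size, num_tokens = 16, 32, 5
    stages = [torch.nn.Linear(hidden_size, hidden_size) for _ in range(3)]
    lm_head = torch.nn.Linear(hidden_size, vocab_size)

    x = torch.randn(num_tokens, hidden_size)  # embeddings from the first "rank"
    for i, stage in enumerate(stages):
        x = stage(x)
        is_last_rank = (i == len(stages) - 1)
        if not is_last_rank:
            # Mid-pipeline "rank": forward the hidden states as-is.
            continue
        # Last "rank": only here are logits computed and tokens sampled.
        logits = lm_head(x)
        next_tokens = logits.argmax(dim=-1)
        print(next_tokens.shape)  # torch.Size([5])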
@@ -1118,12 +1129,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             positions = self.mrope_positions[:, :num_tokens]
         else:
             positions = self.positions[:num_tokens]
-        intermediate_tensors = None
-        if not get_pp_group().is_first_rank:
-            intermediate_tensors = self.model.make_empty_intermediate_tensors(
-                batch_size=num_tokens,
-                dtype=self.model_config.dtype,
-                device=self.device)
+
+        if get_pp_group().is_first_rank:
+            intermediate_tensors = None
+        else:
+            if not hasattr(self, "intermediate_tensors"):
+                self.intermediate_tensors = (
+                    self.model.make_empty_intermediate_tensors(
+                        batch_size=self.max_num_tokens,
+                        dtype=self.model_config.dtype,
+                        device=self.device))
+            intermediate_tensors = IntermediateTensors({
+                k: v[:num_tokens]
+                for k, v in self.intermediate_tensors.items()
+            })
+
         with set_forward_context(None, self.vllm_config):
             hidden_states = model(
                 input_ids=input_ids,
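In the dummy-run path the cache is now created lazily at max_num_tokens and sliced per call, instead of allocating a fresh buffer of size num_tokens on every invocation, so warmup/profiling runs and real runs share the same underlying memory. The pattern in isolation, with hypothetical names and a toy factory standing in for model.make_empty_intermediate_tensors:

    # Isolated sketch of the lazy "allocate once at max size, slice per call"
    # pattern; names are hypothetical, not vLLM's actual attributes.
    from typing import Dict

    import torch


    def make_empty_intermediate_tensors(batch_size: int, hidden_size: int) -> Dict[str, torch.Tensor]:
        return {
            "hidden_states": torch.zeros(batch_size, hidden_size),
            "residual": torch.zeros(batch_size, hidden_size),
        }


    class Runner:
        def __init__(self, max_num_tokens: int = 1024, hidden_size: int = 64):
            self.max_num_tokens = max_num_tokens
            self.hidden_size = hidden_size
            # No cache allocated here; it is created on first use below.

        def dummy_run(self, num_tokens: int) -> Dict[str, torch.Tensor]:
            if not hasattr(self, "intermediate_tensors"):
                # Allocate the full-size cache exactly once.
                self.intermediate_tensors = make_empty_intermediate_tensors(
                    self.max_num_tokens, self.hidden_size)
            # Hand out views sized for this call; real runs would slice the
            # same cache, so dummy and real runs see the same buffers.
            return {k: v[:num_tokens] for k, v in self.intermediate_tensors.items()}


    runner = Runner()
    small = runner.dummy_run(16)
    large = runner.dummy_run(512)
    assert small["hidden_states"].data_ptr() == large["hidden_states"].data_ptr()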