Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-17 12:05:38 +08:00)
[V1] Check all pooling tasks during profiling (#21299)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 2671334d45
commit f59ec35b7f
vllm/sequence.py
@@ -1173,6 +1173,10 @@ class PoolingSequenceGroupOutput(
     # The actual type is in SequenceGroup.pooled_data
     data: Any
 
+    def get_data_nbytes(self) -> int:
+        data: torch.Tensor = self.data
+        return data.nbytes
+
     def __repr__(self) -> str:
         return f"PoolingSequenceGroupOutput(data={self.data}"
 
@@ -1234,6 +1238,9 @@ class PoolerOutput(
     """The output from a pooling operation in the pooling model."""
     outputs: list[PoolingSequenceGroupOutput]
 
+    def get_data_nbytes(self) -> int:
+        return sum(o.get_data_nbytes() for o in self.outputs)
+
     def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
         return self.outputs[idx]
 
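For orientation, the two get_data_nbytes helpers added above only report how many bytes of pooled data a single request (and, summed up, a whole batch) produced; the model-runner change below uses this number to compare pooling tasks by output size. A minimal standalone sketch of the same idea, using plain dataclasses and illustrative names rather than vLLM's actual classes:

# Minimal sketch only -- simplified stand-ins for the classes in the diff above;
# the names here are illustrative, not vLLM's.
from dataclasses import dataclass

import torch


@dataclass
class DummySequenceGroupOutput:
    # Pooled data for a single request, e.g. an embedding vector.
    data: torch.Tensor

    def get_data_nbytes(self) -> int:
        return self.data.nbytes


@dataclass
class DummyPoolerOutput:
    outputs: list[DummySequenceGroupOutput]

    def get_data_nbytes(self) -> int:
        # Total bytes of pooled data across the whole batch.
        return sum(o.get_data_nbytes() for o in self.outputs)


# Four requests, each producing a 1024-dim fp32 embedding:
# 4 * 1024 * 4 bytes = 16384 bytes in total.
batch = DummyPoolerOutput(
    [DummySequenceGroupOutput(torch.zeros(1024)) for _ in range(4)])
print(batch.get_data_nbytes())  # 16384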
vllm/v1/worker/gpu_model_runner.py
@@ -41,7 +41,7 @@ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingParams, PoolingTask
 from vllm.sampling_params import SamplingType
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up)
@@ -1819,7 +1819,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         old_global_expert_indices = None
         rank_mapping = None
 
-        with DeviceMemoryProfiler() as m:  # noqa: SIM117
+        with DeviceMemoryProfiler() as m:
             time_before_load = time.perf_counter()
             model_loader = get_model_loader(self.load_config)
             if not hasattr(self, "model"):
@@ -2215,12 +2215,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         )
         return sampler_output
 
-    @torch.inference_mode()
-    def _dummy_pooler_run(
+    def _dummy_pooler_run_task(
         self,
         hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
+        task: PoolingTask,
+    ) -> PoolerOutput:
         num_tokens = hidden_states.shape[0]
         max_num_reqs = self.scheduler_config.max_num_seqs
         num_reqs = min(num_tokens, max_num_reqs)
@@ -2232,37 +2231,55 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         hidden_states_list = list(
             torch.split(hidden_states, num_scheduled_tokens_list))
 
         req_num_tokens = num_tokens // num_reqs
 
-        model = cast(VllmModelForPooling, self.model)
-        dummy_task = self.get_supported_pooling_tasks()[0]
-        dummy_pooling_params = PoolingParams(task=dummy_task)
+        dummy_prompt_lens = torch.tensor(
+            [h.shape[0] for h in hidden_states_list],
+            device=self.device,
+        )
+        dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
+                                      dtype=torch.int32,
+                                      device=self.device)
 
-        to_update = model.pooler.get_pooling_updates(dummy_task)
+        model = cast(VllmModelForPooling, self.model)
+        dummy_pooling_params = PoolingParams(task=task)
+        to_update = model.pooler.get_pooling_updates(task)
         to_update.apply(dummy_pooling_params)
 
         dummy_metadata = PoolingMetadata(
-            prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
-                                     device=self.device),
-            prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
-                                         dtype=torch.int32,
-                                         device=self.device),
-            pooling_params=[dummy_pooling_params] * num_reqs)
+            prompt_lens=dummy_prompt_lens,
+            prompt_token_ids=dummy_token_ids,
+            pooling_params=[dummy_pooling_params] * num_reqs,
+        )
 
         try:
-            pooler_output = model.pooler(hidden_states=hidden_states_list,
-                                         pooling_metadata=dummy_metadata)
+            return model.pooler(hidden_states=hidden_states_list,
+                                pooling_metadata=dummy_metadata)
         except RuntimeError as e:
             if 'out of memory' in str(e):
                 raise RuntimeError(
-                    "CUDA out of memory occurred when warming up pooler with "
-                    f"{num_reqs} dummy requests. Please try lowering "
-                    "`max_num_seqs` or `gpu_memory_utilization` when "
+                    "CUDA out of memory occurred when warming up pooler "
+                    f"({task=}) with {num_reqs} dummy requests. Please try "
+                    "lowering `max_num_seqs` or `gpu_memory_utilization` when "
                     "initializing the engine.") from e
             else:
                 raise e
-        return pooler_output
+
+    @torch.inference_mode()
+    def _dummy_pooler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> PoolerOutput:
+        # Find the task that has the largest output for subsequent steps
+        output_size = dict[PoolingTask, float]()
+        for task in self.get_supported_pooling_tasks():
+            # Run a full batch with each task to ensure none of them OOMs
+            output = self._dummy_pooler_run_task(hidden_states, task)
+            output_size[task] = output.get_data_nbytes()
+            del output  # Allow GC
+
+        max_task = max(output_size.items(), key=lambda x: x[1])[0]
+        return self._dummy_pooler_run_task(hidden_states, max_task)
 
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
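Putting the pieces together: the old _dummy_pooler_run warmed up the pooler with only the first entry of get_supported_pooling_tasks(), whereas the new version runs a full dummy batch for every supported task (verifying that none of them OOMs) and then repeats the task whose pooled output is largest, so the subsequent memory profiling captures the worst case. A rough standalone sketch of that selection logic; the task names, sizes, and run_task stand-in below are made up for illustration and are not vLLM APIs:

# Illustrative sketch only: run every task once, then re-run the one with the
# largest output. Task names and sizes here are hypothetical.
import torch

SUPPORTED_TASKS = ("encode", "embed", "classify", "score")


def run_task(task: str) -> torch.Tensor:
    # Stand-in for _dummy_pooler_run_task: pretend each task yields pooled
    # output of a different size for the same dummy batch.
    sizes = {"encode": 4096, "embed": 1024, "classify": 8, "score": 1}
    return torch.zeros(sizes[task])


def dummy_pooler_run() -> torch.Tensor:
    # Run a full batch with each task to ensure none of them OOMs, and
    # record how much pooled data each one produces.
    output_size: dict[str, int] = {}
    for task in SUPPORTED_TASKS:
        output = run_task(task)
        output_size[task] = output.nbytes
        del output  # let the memory be reclaimed before the next task

    # Re-run the task with the largest output so that memory profiling
    # reflects the worst case.
    max_task = max(output_size.items(), key=lambda kv: kv[1])[0]
    return run_task(max_task)


print(dummy_pooler_run().shape)  # torch.Size([4096]) -> "encode" wins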