[V1] Check all pooling tasks during profiling (#21299)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 2671334d45
commit f59ec35b7f

@@ -1173,6 +1173,10 @@ class PoolingSequenceGroupOutput(
     # The actual type is in SequenceGroup.pooled_data
     data: Any

+    def get_data_nbytes(self) -> int:
+        data: torch.Tensor = self.data
+        return data.nbytes
+
     def __repr__(self) -> str:
         return f"PoolingSequenceGroupOutput(data={self.data}"

@@ -1234,6 +1238,9 @@ class PoolerOutput(
     """The output from a pooling operation in the pooling model."""
     outputs: list[PoolingSequenceGroupOutput]

+    def get_data_nbytes(self) -> int:
+        return sum(o.get_data_nbytes() for o in self.outputs)
+
     def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
         return self.outputs[idx]

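The two new get_data_nbytes() helpers above give the profiler a way to measure how many bytes a pooling output actually occupies. Below is a rough standalone check of how they add up; it is not part of the commit, and it assumes PoolingSequenceGroupOutput sits next to PoolerOutput in vllm.sequence (the import hunk further down pulls PoolerOutput from there) and that data holds a plain torch.Tensor, as it does in the dummy runs.

import torch

from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput

# Four dummy per-request outputs, each a 1024-dim float32 embedding.
outputs = [
    PoolingSequenceGroupOutput(torch.zeros(1024, dtype=torch.float32))
    for _ in range(4)
]
pooler_output = PoolerOutput(outputs=outputs)

# Each tensor is 1024 * 4 bytes; the batch total is the sum over requests.
assert outputs[0].get_data_nbytes() == 4096
assert pooler_output.get_data_nbytes() == 4 * 4096
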
@@ -41,7 +41,7 @@ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingParams, PoolingTask
 from vllm.sampling_params import SamplingType
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up)

@@ -1819,7 +1819,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         old_global_expert_indices = None
         rank_mapping = None

-        with DeviceMemoryProfiler() as m:  # noqa: SIM117
+        with DeviceMemoryProfiler() as m:
             time_before_load = time.perf_counter()
             model_loader = get_model_loader(self.load_config)
             if not hasattr(self, "model"):

@@ -2215,12 +2215,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         )
         return sampler_output

     @torch.inference_mode()
-    def _dummy_pooler_run(
+    def _dummy_pooler_run_task(
         self,
         hidden_states: torch.Tensor,
-    ) -> torch.Tensor:
-
+        task: PoolingTask,
+    ) -> PoolerOutput:
         num_tokens = hidden_states.shape[0]
         max_num_reqs = self.scheduler_config.max_num_seqs
         num_reqs = min(num_tokens, max_num_reqs)

@@ -2232,37 +2231,55 @@ class GPUModelRunner(LoRAModelRunnerMixin):

         hidden_states_list = list(
             torch.split(hidden_states, num_scheduled_tokens_list))

         req_num_tokens = num_tokens // num_reqs

-        model = cast(VllmModelForPooling, self.model)
-        dummy_task = self.get_supported_pooling_tasks()[0]
-        dummy_pooling_params = PoolingParams(task=dummy_task)
+        dummy_prompt_lens = torch.tensor(
+            [h.shape[0] for h in hidden_states_list],
+            device=self.device,
+        )
+        dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
+                                      dtype=torch.int32,
+                                      device=self.device)

-        to_update = model.pooler.get_pooling_updates(dummy_task)
+        model = cast(VllmModelForPooling, self.model)
+        dummy_pooling_params = PoolingParams(task=task)
+        to_update = model.pooler.get_pooling_updates(task)
         to_update.apply(dummy_pooling_params)

         dummy_metadata = PoolingMetadata(
-            prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
-                                     device=self.device),
-            prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
-                                         dtype=torch.int32,
-                                         device=self.device),
-            pooling_params=[dummy_pooling_params] * num_reqs)
+            prompt_lens=dummy_prompt_lens,
+            prompt_token_ids=dummy_token_ids,
+            pooling_params=[dummy_pooling_params] * num_reqs,
+        )

         try:
-            pooler_output = model.pooler(hidden_states=hidden_states_list,
-                                         pooling_metadata=dummy_metadata)
+            return model.pooler(hidden_states=hidden_states_list,
+                                pooling_metadata=dummy_metadata)
         except RuntimeError as e:
             if 'out of memory' in str(e):
                 raise RuntimeError(
-                    "CUDA out of memory occurred when warming up pooler with "
-                    f"{num_reqs} dummy requests. Please try lowering "
-                    "`max_num_seqs` or `gpu_memory_utilization` when "
+                    "CUDA out of memory occurred when warming up pooler "
+                    f"({task=}) with {num_reqs} dummy requests. Please try "
+                    "lowering `max_num_seqs` or `gpu_memory_utilization` when "
                     "initializing the engine.") from e
             else:
                 raise e
-        return pooler_output

+    @torch.inference_mode()
+    def _dummy_pooler_run(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> PoolerOutput:
+        # Find the task that has the largest output for subsequent steps
+        output_size = dict[PoolingTask, float]()
+        for task in self.get_supported_pooling_tasks():
+            # Run a full batch with each task to ensure none of them OOMs
+            output = self._dummy_pooler_run_task(hidden_states, task)
+            output_size[task] = output.get_data_nbytes()
+            del output  # Allow GC
+
+        max_task = max(output_size.items(), key=lambda x: x[1])[0]
+        return self._dummy_pooler_run_task(hidden_states, max_task)
+
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.

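Taken together, the new _dummy_pooler_run is a small worst-case search: it warms up every supported pooling task once so none of them can OOM later, then reruns the task whose output is largest so the memory profiler records the most pessimistic footprint. A minimal standalone sketch of that strategy follows; the run_task callable and the task names are hypothetical stand-ins for GPUModelRunner._dummy_pooler_run_task and the values returned by get_supported_pooling_tasks().

from typing import Callable

import torch


def pick_worst_case_task(
    tasks: list[str],
    run_task: Callable[[str], torch.Tensor],
) -> str:
    """Run every task once and return the one with the largest output."""
    output_size: dict[str, int] = {}
    for task in tasks:
        output = run_task(task)            # each task is warmed up exactly once
        output_size[task] = output.nbytes  # bytes occupied by this task's output
        del output                         # let the allocator reclaim the buffer
    # Keep the task with the largest output for the final profiling pass.
    return max(output_size.items(), key=lambda kv: kv[1])[0]


# Example: an embedding task dominates a two-class classification task.
fake_outputs = {"embed": torch.zeros(8, 1024), "classify": torch.zeros(8, 2)}
assert pick_worst_case_task(list(fake_outputs), fake_outputs.__getitem__) == "embed"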