[V1][VLM] Proper memory profiling for image language models (#11210)

Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: ywang96 <ywang@example.com>

commit 59c9b6ebeb
parent 66d4b16724
@@ -1280,6 +1280,14 @@ class SchedulerConfig:
     is_multimodal_model: bool = False
 
+    # FIXME(woosuk & ywang96): Below are placeholder values. We need to
+    # calculate the actual values from the configurations.
+    # Multimodal encoder run compute budget, only used in V1
+    max_num_encoder_input_tokens = 16384
+
+    # Multimodal encoder cache size, only used in V1
+    encoder_cache_size = 16384
+
     # Whether to perform preemption by swapping or
     # recomputation. If not specified, we determine the mode as follows:
     # We use recomputation by default since it incurs lower overhead than
@@ -245,6 +245,11 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
             # Do not split, return as tensor of shape [1, fs, hs]
             return image_embeds.unsqueeze(0)
 
+        # If the last split index is the last index in image_tokens, we
+        # ignore it to avoid empty split tensor
+        if split_indices[-1] == len(image_tokens):
+            split_indices = split_indices[:-1]
+
         image_embeds = image_embeds.tensor_split(split_indices.cpu())
         return image_embeds
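For context, the guard above exists because torch.tensor_split returns one extra (possibly empty) chunk when a split index equals the length of the split dimension. A standalone sketch, assuming only PyTorch (shapes are made up):

import torch

# 6 image "tokens" with a hidden size of 2 (hypothetical shapes).
image_embeds = torch.arange(12).reshape(6, 2)
num_image_tokens = image_embeds.shape[0]
split_indices = torch.tensor([2, 4, 6])  # last index == num_image_tokens

# Without the guard: the final chunk is empty.
print([c.shape[0] for c in image_embeds.tensor_split(split_indices.cpu())])  # [2, 2, 2, 0]

# With the guard: drop the trailing index, no empty chunk.
if split_indices[-1] == num_image_tokens:
    split_indices = split_indices[:-1]
print([c.shape[0] for c in image_embeds.tensor_split(split_indices.cpu())])  # [2, 2, 2]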
@@ -200,6 +200,23 @@ class MultiModalRegistry:
         """
         return self.register_max_multimodal_tokens("image", max_mm_tokens)
 
+    def get_max_tokens_per_item_by_modality(
+        self,
+        model_config: "ModelConfig",
+    ) -> Mapping[str, int]:
+        """
+        Get the maximum number of tokens per data item from each modality
+        for profiling the memory usage of a model.
+
+        Note:
+            This is currently directly used only in V1.
+        """
+
+        return {
+            key: plugin.get_max_multimodal_tokens(model_config)
+            for key, plugin in self._plugins.items()
+        }
+
     def get_max_tokens_by_modality(
         self,
         model_config: "ModelConfig",
@@ -216,9 +233,9 @@ class MultiModalRegistry:
         limits_per_plugin = self._limits_by_model[model_config]
 
         return {
-            key: (limits_per_plugin[key] *
-                  plugin.get_max_multimodal_tokens(model_config))
-            for key, plugin in self._plugins.items()
+            key: limits_per_plugin[key] * max_tokens_per_mm_item
+            for key, max_tokens_per_mm_item in
+            self.get_max_tokens_per_item_by_modality(model_config).items()
         }
 
     def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
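The refactor above makes the per-request totals a pure function of the new per-item maxima. A stand-in with plain dicts and hypothetical numbers (not the vLLM registry API) to show the relationship:

# Per-item maxima, as get_max_tokens_per_item_by_modality would report them.
max_tokens_per_item = {"image": 576, "video": 4096}  # hypothetical values
# How many items of each modality a single request may contain.
limits_per_plugin = {"image": 4, "video": 1}

# get_max_tokens_by_modality is now just "limit x per-item maximum".
max_tokens_by_modality = {
    key: limits_per_plugin[key] * per_item
    for key, per_item in max_tokens_per_item.items()
}
print(max_tokens_by_modality)  # {'image': 2304, 'video': 4096}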
@@ -73,14 +73,13 @@ class Scheduler:
         # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
         # projector if needed). Currently, we assume that the encoder also
         # has the Transformer architecture (e.g., ViT).
-        # FIXME(woosuk): Below are placeholder values. We need to calculate the
-        # actual values from the configurations.
-        self.max_num_encoder_input_tokens = 16384
+        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  #noqa: E501
         # NOTE(woosuk): For the models without encoder (e.g., text-only models),
         # the encoder cache will not be initialized and used, regardless of
         # the cache size. This is because the memory space for the encoder cache
         # is preallocated in the profiling run.
-        self.encoder_cache_manager = EncoderCacheManager(cache_size=16384)
+        self.encoder_cache_manager = EncoderCacheManager(
+            cache_size=self.scheduler_config.encoder_cache_size)
 
     def schedule(self) -> "SchedulerOutput":
         # NOTE(woosuk) on the scheduling algorithm:
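Both knobs are only read in V1: the first bounds how many encoder tokens may be handed to the model per scheduling step, the second bounds the encoder-output cache. A toy sketch of how a per-step budget of this kind is typically consumed (illustrative only, not the vLLM scheduler's actual logic):

# Toy budget consumption: pick encoder inputs until the per-step
# compute budget (here the placeholder 16384) is exhausted.
def pick_encoder_inputs(num_tokens_per_input, budget=16384):
    scheduled = []
    for idx, num_tokens in enumerate(num_tokens_per_input):
        if num_tokens > budget:
            break  # defer the rest to a later scheduling step
        budget -= num_tokens
        scheduled.append(idx)
    return scheduled, budget

print(pick_encoder_inputs([576, 576, 16000]))  # ([0, 1], 15232)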
@@ -54,6 +54,7 @@ class MMInputMapperClient:
             logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
                          self.mm_cache_hits / self.mm_cache_total)
 
+    # TODO: Support modalities beyond image.
     def process_inputs(
         self,
         mm_data: MultiModalDataDict,
@@ -10,15 +10,16 @@ import torch.nn as nn
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
-from vllm.inputs import INPUT_REGISTRY, InputRegistry
+from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
-from vllm.multimodal import MultiModalKwargs
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.sampling_params import SamplingType
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LayerBlockType, cdiv, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
                                                    FlashAttentionMetadata)
+from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -35,7 +36,6 @@ class GPUModelRunner:
         self,
         vllm_config: VllmConfig,
         device: torch.device,
-        input_registry: InputRegistry = INPUT_REGISTRY,
     ):
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
@@ -77,7 +77,12 @@ class GPUModelRunner:
         self.hidden_size = model_config.get_hidden_size()
 
         # Multi-modal data support
-        self.input_registry = input_registry
+        self.input_registry = INPUT_REGISTRY
+        self.mm_registry = MULTIMODAL_REGISTRY
+        # NOTE: mm_input_mapper is only used for memory profiling.
+        self.mm_input_mapper = MMInputMapperClient(self.model_config)
+        self.max_num_encoder_input_tokens = self.scheduler_config.max_num_encoder_input_tokens  # noqa: E501
+        self.encoder_cache_size = self.scheduler_config.encoder_cache_size
 
         # Lazy initialization
         # self.model: nn.Module  # Set after load_model
@@ -599,8 +604,6 @@ class GPUModelRunner:
         return hidden_states
 
     def profile_run(self) -> None:
-        # TODO(woosuk): Profile the max memory usage of the encoder and
-        # the encoder cache.
         # use an empty tensor instead of `None`` to force Dynamo to pass
         # it by reference, rather by specializing on the value `None`.
         # the `dtype` argument does not matter, and we use `float32` as
@@ -612,6 +615,57 @@ class GPUModelRunner:
             torch.tensor([], dtype=torch.float32, device=self.device)
             for _ in range(self.num_attn_layers)
         ]
 
+        # Profile with multimodal encoder & encoder cache.
+        # TODO (ywang96): generalize this beyond image modality since
+        # mm_input_mapper only supports image inputs.
+        if self.is_multimodal_model:
+
+            # Create dummy batch of multimodal inputs.
+            dummy_request_data = self.input_registry.dummy_data_for_profiling(
+                model_config=self.model_config,
+                seq_len=self.max_num_tokens,
+                mm_registry=self.mm_registry,
+            )
+            dummy_mm_data = dummy_request_data.multi_modal_data
+            dummy_mm_kwargs, _ = self.mm_input_mapper.process_inputs(
+                mm_data=dummy_mm_data,
+                mm_hashes=None,
+                mm_processor_kwargs=None,
+                precomputed_mm_inputs=None)
+
+            # NOTE: Currently model is profiled with a single non-text
+            # modality even when it supports multiple.
+            max_tokens_per_mm_item = max(
+                self.mm_registry.get_max_tokens_per_item_by_modality(
+                    self.model_config).values())
+
+            max_num_mm_items = min(
+                self.max_num_encoder_input_tokens,
+                self.encoder_cache_size) // max_tokens_per_mm_item
+
+            # Dummy data definition in V0 may contain multiple multimodal items
+            # (e.g., multiple images) for a single request, therefore here we
+            # always replicate the first item by max_num_mm_items times since
+            # in V1 they are scheduled to be processed separately.
+            batched_dummy_mm_inputs = MultiModalKwargs.batch(
+                [dummy_mm_kwargs[0]] * max_num_mm_items)
+            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
+                batched_dummy_mm_inputs, device=self.device)
+
+            # Run multimodal encoder.
+            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
+                **batched_dummy_mm_inputs)
+            assert len(dummy_encoder_outputs) == max_num_mm_items, (
+                "Expected dimension 0 of encoder outputs to match the number "
+                f"of multimodal data items: {max_num_mm_items}, got "
+                f"{len(dummy_encoder_outputs)=} instead. This is most likely "
+                "due to the 'get_multimodal_embeddings' method of the model "
+                "not being implemented correctly.")
+
+            # Cache the dummy encoder outputs.
+            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
+
         # Trigger compilation for general shape.
         hidden_states = self._dummy_run(self.model, self.max_num_tokens,
                                         dummy_kv_caches)
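To make the sizing above concrete, here is the same arithmetic with the placeholder budgets and a hypothetical 576-token-per-image encoder (the real per-item count comes from the multimodal registry for the loaded model):

# Worked example of the profiling batch size (hypothetical per-item count).
max_num_encoder_input_tokens = 16384  # placeholder compute budget
encoder_cache_size = 16384            # placeholder cache budget
max_tokens_per_mm_item = 576          # e.g. one image -> 576 encoder tokens

max_num_mm_items = min(max_num_encoder_input_tokens,
                       encoder_cache_size) // max_tokens_per_mm_item
print(max_num_mm_items)  # 28 -> the dummy image is replicated 28 times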
@@ -620,6 +674,7 @@ class GPUModelRunner:
         # TODO(woosuk): Consider the memory usage of the sampler.
         torch.cuda.synchronize()
         del hidden_states, logits
+        self.encoder_cache.clear()
         gc.collect()
 
     def capture_model(self) -> None: