[Core] Factor out common logic for MM budget calculation (#22228)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-08-05 14:54:55 +08:00 (committed by GitHub)
parent e79a12fc3a
commit 811ac13d03
3 changed files with 299 additions and 216 deletions
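
In short: the encoder-budget bookkeeping that was previously duplicated in the GPU and TPU model runners (the call to compute_encoder_budget plus the inline per-modality item math) is factored out into a shared MultiModalBudget helper in the worker utils module. A minimal sketch of the resulting call pattern, using only names introduced in this diff; the surrounding runner attributes (self.model_config, self.scheduler_config, and so on) are assumed to exist as before:

    # Construction in both runners; None for text-only models.
    self.mm_budget = (MultiModalBudget(
        self.model_config,
        self.scheduler_config,
        self.mm_registry,
        max_model_len=self.max_model_len,
        max_num_reqs=self.max_num_reqs,
    ) if self.is_multimodal_model else None)

    # Profiling in both runners.
    if self.is_multimodal_model:
        mm_budget = self.mm_budget
        assert mm_budget is not None
        if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
            modality, max_tokens = mm_budget.get_modality_with_max_tokens()
            _, max_items_per_batch = mm_budget.get_max_items(modality, max_tokens)
            # ...build a dummy batch of max_items_per_batch items and run the encoder.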

View File

@@ -36,7 +36,8 @@ from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
+                                    PlaceholderRange)
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
@@ -51,7 +52,6 @@ from vllm.v1.attention.backends.utils import (
     make_kv_sharing_fast_prefill_attention_metadata,
     make_local_attention_virtual_batches,
     reorder_batch_to_split_decodes_and_prefills)
-from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         ChunkedLocalAttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
@@ -73,7 +73,7 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from ..sample.logits_processor import LogitsProcessorManager
-from .utils import (bind_kv_cache, gather_mm_placeholders,
+from .utils import (MultiModalBudget, bind_kv_cache, gather_mm_placeholders,
                     initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
@@ -148,14 +148,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.mm_registry = MULTIMODAL_REGISTRY
         self.uses_mrope = model_config.uses_mrope

-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
-            mm_registry=self.mm_registry,
-        )
-        self.max_num_encoder_input_tokens = encoder_compute_budget
-        self.encoder_cache_size = encoder_cache_size
-
         # Sampler
         self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
@@ -330,6 +322,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
             self.max_num_tokens, dtype=torch.int32, device=self.device)

+        self.mm_budget = (MultiModalBudget(
+            self.model_config,
+            self.scheduler_config,
+            self.mm_registry,
+            max_model_len=self.max_model_len,
+            max_num_reqs=self.max_num_reqs,
+        ) if self.is_multimodal_model else None)
+
         self.reorder_batch_threshold: Optional[int] = None

     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
@@ -578,37 +578,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Refresh batch metadata with any pending updates.
         self.input_batch.refresh_metadata()

-    def _init_model_kwargs_for_multimodal_model(
+    def _extract_mm_kwargs(
         self,
-        scheduler_output: Optional["SchedulerOutput"] = None,
-        num_reqs: int = -1,
-    ) -> dict[str, Any]:
-        model_kwargs: dict[str, Any] = {}
-        if self.is_multimodal_raw_input_supported:
-            # This model requires the raw multimodal data in input.
+        scheduler_output: "SchedulerOutput",
+    ) -> BatchedTensorInputs:
+        if self.is_multimodal_raw_input_supported:  # noqa: SIM102
             if scheduler_output:
-                multi_modal_kwargs_list = []
+                multi_modal_kwargs_list = list[MultiModalKwargs]()
                 for req in scheduler_output.scheduled_new_reqs:
                     req_mm_inputs = req.mm_inputs
                     if not isinstance(req_mm_inputs, list):
                         req_mm_inputs = list(req_mm_inputs)
                     multi_modal_kwargs_list.extend(req_mm_inputs)
-                multi_modal_kwargs = MultiModalKwargs.batch(
-                    multi_modal_kwargs_list)
-            else:
-                # The only case where SchedulerOutput is None is for
-                # a dummy run let's get some dummy data.
-                dummy_data = [
-                    self.mm_registry.get_decoder_dummy_data(
-                        model_config=self.model_config,
-                        seq_len=1).multi_modal_data for i in range(num_reqs)
-                ]
-                multi_modal_kwargs = MultiModalKwargs.batch(dummy_data)
-            model_kwargs.update(multi_modal_kwargs)
-        return model_kwargs
+                return MultiModalKwargs.batch(multi_modal_kwargs_list)
+
+        return {}
+
+    def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
+        if self.is_multimodal_raw_input_supported:
+            mm_budget = self.mm_budget
+            assert mm_budget is not None
+
+            dummy_modality, _ = mm_budget.get_modality_with_max_tokens()
+
+            return self._get_mm_dummy_batch(dummy_modality, num_seqs)
+
+        return {}

     def _get_cumsum_and_arange(
         self,
@@ -1517,19 +1513,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
-            input_ids = self.input_ids[:num_scheduled_tokens]
-
-            model_kwargs = self._init_model_kwargs_for_multimodal_model(
-                scheduler_output=scheduler_output)
-            inputs_embeds = self.model.get_input_embeddings(
-                input_ids=input_ids,
+            inputs_embeds_scheduled = self.model.get_input_embeddings(
+                input_ids=self.input_ids[:num_scheduled_tokens],
                 multimodal_embeddings=mm_embeds or None,
             )

             # TODO(woosuk): Avoid the copy. Optimize.
-            self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
-            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            self.inputs_embeds[:num_scheduled_tokens].copy_(
+                inputs_embeds_scheduled)

             input_ids = None
+            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            model_mm_kwargs = self._extract_mm_kwargs(scheduler_output)
         else:
             # For text-only models, we use token ids as input.
             # While it is possible to use embeddings as input just like the
@@ -1537,7 +1532,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids[:num_input_tokens]
             inputs_embeds = None
-            model_kwargs = {}
+            model_mm_kwargs = {}
         if self.uses_mrope:
             positions = self.mrope_positions[:, :num_input_tokens]
         else:
@@ -1571,7 +1566,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
                 **MultiModalKwargs.as_kwargs(
-                    model_kwargs,
+                    model_mm_kwargs,
                     device=self.device,
                 ),
             )
@@ -2149,6 +2144,30 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             yield
             input_ids.fill_(0)

+    def _get_mm_dummy_batch(
+        self,
+        modality: str,
+        max_items_per_batch: int,
+    ) -> BatchedTensorInputs:
+        """Dummy data for profiling and precompiling multimodal models."""
+        dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
+            model_config=self.model_config,
+            seq_len=self.max_num_tokens,
+            mm_counts={modality: 1},
+        )
+        dummy_mm_data = dummy_decoder_data.multi_modal_data
+
+        # Result in the maximum GPU consumption of the model
+        dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
+        dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
+
+        batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
+                                                         max_items_per_batch)
+
+        return MultiModalKwargs.as_kwargs(
+            batched_dummy_mm_inputs,
+            device=self.device,
+        )
+
     @torch.inference_mode()
     def _dummy_run(
         self,
@@ -2213,16 +2232,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
-            model = self.model
             if self.is_multimodal_model:
-                model_kwargs = self._init_model_kwargs_for_multimodal_model(
-                    num_reqs=num_reqs)
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
+                model_mm_kwargs = self._dummy_mm_kwargs(num_reqs)
             else:
                 input_ids = self.input_ids[:num_tokens]
                 inputs_embeds = None
-                model_kwargs = {}
+                model_mm_kwargs = {}

             if self.uses_mrope:
                 positions = self.mrope_positions[:, :num_tokens]
@@ -2247,13 +2264,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     self.vllm_config,
                     num_tokens=num_tokens,
                     num_tokens_across_dp=num_tokens_across_dp):
-                outputs = model(
+                outputs = self.model(
                     input_ids=input_ids,
                     positions=positions,
                     intermediate_tensors=intermediate_tensors,
                     inputs_embeds=inputs_embeds,
                     **MultiModalKwargs.as_kwargs(
-                        model_kwargs,
+                        model_mm_kwargs,
                         device=self.device,
                     ),
                 )
@@ -2423,62 +2440,37 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def profile_run(self) -> None:
         # Profile with multimodal encoder & encoder cache.
-        # TODO: handle encoder-decoder models once we support them.
-        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
-                and self.encoder_cache_size > 0):
+        if self.is_multimodal_model:
+            mm_budget = self.mm_budget
+            assert mm_budget is not None
+
+            # TODO: handle encoder-decoder models once we support them.
+            if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
                 # NOTE: Currently model is profiled with a single non-text
                 # modality with the max possible input tokens even when
                 # it supports multiple.
-            max_tokens_by_modality_dict = self.mm_registry \
-                .get_max_tokens_per_item_by_nonzero_modality(self.model_config)
-            dummy_data_modality, max_tokens_per_mm_item = max(
-                max_tokens_by_modality_dict.items(), key=lambda item: item[1])
-
-            # Check how many items of this modality can be supported by
-            # the encoder budget.
-            encoder_budget = min(self.max_num_encoder_input_tokens,
-                                 self.encoder_cache_size)
-
-            max_num_mm_items_encoder_budget = encoder_budget // \
-                max_tokens_per_mm_item
-
-            # Check how many items of this modality can be supported by
-            # the decoder budget.
-            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
-                self.model_config)[dummy_data_modality]
-
-            # NOTE: We do not consider max_num_batched_tokens on purpose
-            # because the multimodal embeddings can be generated in advance
-            # and chunked prefilled.
-            max_num_mm_items_decoder_budget = self.max_num_reqs * \
-                max_mm_items_per_req
-
-            max_num_mm_items = max(
-                1,
-                min(max_num_mm_items_encoder_budget,
-                    max_num_mm_items_decoder_budget))
+                (
+                    dummy_modality,
+                    max_tokens,
+                ) = mm_budget.get_modality_with_max_tokens()
+                (
+                    max_mm_items_per_prompt,
+                    max_mm_items_per_batch,
+                ) = mm_budget.get_max_items(dummy_modality, max_tokens)

                 logger.info(
-                "Encoder cache will be initialized with a budget of %s tokens,"
-                " and profiled with %s %s items of the maximum feature size.",
-                encoder_budget, max_num_mm_items, dummy_data_modality)
+                    "Encoder cache will be initialized with a budget of "
+                    "%s tokens, and profiled with %s %s items of the maximum "
+                    "feature size.",
+                    encoder_budget,
+                    max_mm_items_per_batch,
+                    dummy_modality,
+                )

                 # Create dummy batch of multimodal inputs.
-            dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
-                model_config=self.model_config,
-                seq_len=max_tokens_per_mm_item,
-                mm_counts={
-                    dummy_data_modality: 1
-                },
-            ).multi_modal_data
-
-            batched_dummy_mm_inputs = MultiModalKwargs.batch(
-                [dummy_mm_kwargs] * max_num_mm_items,
-                pin_memory=self.pin_memory)
-            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
-                batched_dummy_mm_inputs,
-                device=self.device,
-            )
+                batched_dummy_mm_inputs = self._get_mm_dummy_batch(
+                    dummy_modality,
+                    max_mm_items_per_batch,
+                )

                 # Run multimodal encoder.
@@ -2487,11 +2479,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 sanity_check_mm_encoder_outputs(
                     dummy_encoder_outputs,
-                expected_num_items=max_num_mm_items,
+                    expected_num_items=max_mm_items_per_batch,
                 )

                 # Cache the dummy encoder outputs.
-            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
+                self.encoder_cache["tmp"] = dict(
+                    enumerate(dummy_encoder_outputs))

         # Add `is_profile` here to pre-allocate communication buffers
         hidden_states, last_hidden_states \

View File

@@ -42,7 +42,6 @@ from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE,
                                                PallasAttentionBackend,
                                                PallasMetadata,
                                                get_page_size_bytes)
-from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
                                         SlidingWindowSpec)
@@ -55,7 +54,8 @@ from vllm.v1.worker.kv_connector_model_runner_mixin import (
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch
-from .utils import (bind_kv_cache, initialize_kv_cache_for_kv_sharing,
+from .utils import (MultiModalBudget, bind_kv_cache,
+                    initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs)

 if TYPE_CHECKING:
@@ -195,14 +195,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # TODO: Support M-RoPE (e.g, Qwen2-VL)
         assert not self.uses_mrope, "TPU does not support M-RoPE yet."

-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
-            mm_registry=self.mm_registry,
-        )
-        self.max_num_encoder_input_tokens = encoder_compute_budget
-        self.encoder_cache_size = encoder_cache_size
-
         self._num_slices_per_kv_cache_update_block = \
             _get_num_slices_per_kv_cache_update_block(get_page_size_bytes(
                 block_size=self.block_size,
@@ -294,36 +286,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.structured_decode_arange = torch.arange(
             0, 32, device="cpu", pin_memory=self.pin_memory)

-        # Get maximum number of mm items per modality (batch size).
-        self.max_num_mm_items_by_modality = dict()
-        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
-                and self.encoder_cache_size > 0):
-            max_tokens_by_modality_dict = (
-                MULTIMODAL_REGISTRY.
-                get_max_tokens_per_item_by_nonzero_modality(self.model_config))
-            for modality, max_tokens in max_tokens_by_modality_dict.items():
-                # Check how many items of this modality can be supported by
-                # the encoder budget.
-                encoder_budget = min(self.max_num_encoder_input_tokens,
-                                     self.encoder_cache_size)
-
-                max_num_mm_items_encoder_budget = cdiv(encoder_budget,
-                                                       max_tokens)
-
-                # Check how many items of this modality can be supported by
-                # the decoder budget.
-                max_mm_items_per_req = self.mm_registry.\
-                    get_mm_limits_per_prompt(self.model_config)[modality]
-
-                # NOTE: We do not consider max_num_batched_tokens on purpose
-                # because the multimodal embeddings can be generated in advance
-                # and chunked prefilled.
-                max_num_mm_items_decoder_budget = self.max_num_reqs * \
-                    max_mm_items_per_req
-
-                max_num_mm_items = min(max_num_mm_items_encoder_budget,
-                                       max_num_mm_items_decoder_budget)
-
-                self.max_num_mm_items_by_modality[modality] = max_num_mm_items
+        self.mm_budget = (MultiModalBudget(
+            self.model_config,
+            self.scheduler_config,
+            self.mm_registry,
+            max_model_len=self.max_model_len,
+            max_num_reqs=self.max_num_reqs,
+        ) if self.is_multimodal_model else None)

         if not self.use_spmd:
             self.sample_from_logits_func = torch.compile(
@@ -1335,23 +1304,33 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         xm.mark_step()  # Captures metadata updates

     def _precompile_mm_encoder(self) -> None:
+        if not self.is_multimodal_model:
+            return
+
         # Pre-compile MM encoder for all supported data modalities.
         hf_config = self.vllm_config.model_config.hf_config
-        for mode, max_items_by_mode in \
-                self.max_num_mm_items_by_modality.items():
+
+        mm_budget = self.mm_budget
+        assert mm_budget is not None
+
+        max_items_per_seq_by_modality = mm_budget.max_items_per_batch_by_modality  # noqa: E501
+
+        for mode, max_items_per_seq in max_items_per_seq_by_modality.items():
             logger.info(
                 "Compiling Multimodal %s Encoder with different input"
                 " shapes.", mode)
             start = time.perf_counter()
             # No padding for MM encoder just yet.
-            for num_items in range(1, max_items_by_mode + 1):
+            for num_items in range(1, max_items_per_seq + 1):
                 logger.info("  -- mode: %s items: %d", mode, num_items)
                 batched_dummy_mm_inputs = self._get_mm_dummy_batch(
-                    mode, num_items)
+                    mode,
+                    num_items,
+                )
                 # Run multimodal encoder.
                 xm.mark_step()
-                mm_embeds = self.model.\
-                    get_multimodal_embeddings(**batched_dummy_mm_inputs)
+                mm_embeds = self.model.get_multimodal_embeddings(
+                    **batched_dummy_mm_inputs)
                 xm.mark_step()
                 num_patches = mm_embeds[0].shape[0]
                 items_size = num_patches * num_items
@@ -1547,27 +1526,38 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_tokens: int,
     ) -> None:
         # Profile with multimodal encoder & encoder cache.
-        # TODO: handle encoder-decoder models once we support them.
-        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
-                and self.encoder_cache_size > 0):
+        if self.is_multimodal_model:
+            mm_budget = self.mm_budget
+            assert mm_budget is not None
+
+            # TODO: handle encoder-decoder models once we support them.
+            if (encoder_budget := mm_budget.get_encoder_budget()) > 0:
                 # NOTE: Currently model is profiled with a single non-text
                 # modality with the max possible input tokens even when
                 # it supports multiple.
-            dummy_data_modality, max_num_mm_items = max(
-                self.max_num_mm_items_by_modality.items(), key=lambda t: t[1])
-
-            encoder_budget = min(self.max_num_encoder_input_tokens,
-                                 self.encoder_cache_size)
+                (
+                    dummy_modality,
+                    max_tokens,
+                ) = mm_budget.get_modality_with_max_tokens()
+                (
+                    max_mm_items_per_prompt,
+                    max_mm_items_per_batch,
+                ) = mm_budget.get_max_items(dummy_modality, max_tokens)

                 logger.info(
-                "Encoder cache will be initialized with a budget of %d tokens,"
-                " and profiled with %s %s items of the maximum feature size.",
-                encoder_budget, max_num_mm_items, dummy_data_modality)
+                    "Encoder cache will be initialized with a budget of "
+                    "%s tokens, and profiled with %s %s items of the maximum "
+                    "feature size.",
+                    encoder_budget,
+                    max_mm_items_per_batch,
+                    dummy_modality,
+                )

                 # Create dummy batch of multimodal inputs.
                 batched_dummy_mm_inputs = self._get_mm_dummy_batch(
-                dummy_data_modality, max_num_mm_items)
+                    dummy_modality,
+                    max_mm_items_per_batch,
+                )

                 # Run multimodal encoder.
                 # Isolate encoder graph from post-processing to minimize
@@ -1583,15 +1573,14 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 "Multimodal Encoder profiling finished in in %.2f [secs].",
                 end - start)

-            assert len(dummy_encoder_outputs) == max_num_mm_items, (
-                "Expected dimension 0 of encoder outputs to match the number "
-                f"of multimodal data items: {max_num_mm_items}, got "
-                f"{len(dummy_encoder_outputs)=} instead. This is most likely "
-                "due to the 'get_multimodal_embeddings' method of the model "
-                "not implemented correctly.")
+                sanity_check_mm_encoder_outputs(
+                    dummy_encoder_outputs,
+                    expected_num_items=max_mm_items_per_batch,
+                )

                 # Cache the dummy encoder outputs.
-            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
+                self.encoder_cache["tmp"] = dict(
+                    enumerate(dummy_encoder_outputs))

         # Trigger compilation for general shape.
         self._dummy_run(num_tokens, self.num_reqs_max_model_len,
@@ -1809,33 +1798,25 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.grammar_bitmask_cpu[:num_reqs].to(logits.device), \
             self.structured_decode_arange.to(logits.device)

-    def _get_mm_dummy_batch(self, modality: str,
-                            batch_size: int) -> BatchedTensorInputs:
-        # Dummy data for pre-compiling multimodal models.
-        dummy_request_data = self.mm_registry.get_decoder_dummy_data(
+    def _get_mm_dummy_batch(
+        self,
+        modality: str,
+        max_items_per_batch: int,
+    ) -> BatchedTensorInputs:
+        """Dummy data for profiling and precompiling multimodal models."""
+        dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
             model_config=self.model_config,
             seq_len=self.max_num_tokens,
+            mm_counts={modality: 1},
         )
-        dummy_mm_data = dummy_request_data.multi_modal_data
+        dummy_mm_data = dummy_decoder_data.multi_modal_data

-        # Dummy data definition in V0 may contain multiple multimodal items
-        # (e.g, multiple images) for a single request, therefore here we
-        # always replicate first item by max_num_mm_items times since in V1
-        # they are scheduled to be processed separately.
-        assert isinstance(dummy_mm_data, MultiModalKwargs), (
-            "Expected dummy multimodal data to be of type "
-            f"MultiModalKwargs, got {type(dummy_mm_data)=} instead. "
-            "This is most likely due to the model not having a merged "
-            "processor.")
-
-        # When models have a merged processor, their dummy data is
-        # already batched `MultiModalKwargs`, therefore we take the first
-        # `MultiModalKwargsItem` from the desired modality to profile on.
+        # Result in the maximum GPU consumption of the model
         dummy_mm_item = dummy_mm_data.get_item(modality=modality, item_index=0)
         dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])

         batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
-                                                         batch_size)
+                                                         max_items_per_batch)
         return MultiModalKwargs.as_kwargs(
             batched_dummy_mm_inputs,
             device=self.device,

View File

@@ -5,14 +5,123 @@ from typing import TYPE_CHECKING, Optional

 import torch

+from vllm.config import ModelConfig, SchedulerConfig
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.utils import extract_layer_index
+from vllm.multimodal.registry import MultiModalRegistry
+from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import KVCacheGroupSpec

 if TYPE_CHECKING:
     from vllm.attention.layer import Attention

+
+class MultiModalBudget:
+    """Helper class to calculate budget information for multi-modal models."""
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        scheduler_config: SchedulerConfig,
+        mm_registry: MultiModalRegistry,
+        *,
+        max_model_len: int,
+        max_num_reqs: int,
+    ) -> None:
+        super().__init__()
+
+        self.model_config = model_config
+        self.scheduler_config = scheduler_config
+        self.mm_registry = mm_registry
+
+        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
+            model_config=model_config,
+            scheduler_config=scheduler_config,
+            mm_registry=mm_registry,
+        )
+        self.max_num_encoder_input_tokens = encoder_compute_budget
+        self.encoder_cache_size = encoder_cache_size
+
+        self.max_model_len = max_model_len
+        self.max_num_reqs = max_num_reqs
+
+        self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)
+
+        max_items_per_prompt_by_modality = dict[str, int]()
+        max_items_per_batch_by_modality = dict[str, int]()
+
+        max_tokens_by_modality = mm_registry \
+            .get_max_tokens_per_item_by_nonzero_modality(model_config)
+
+        for modality, max_tokens in max_tokens_by_modality.items():
+            (
+                max_items_per_prompt,
+                max_items_per_batch,
+            ) = self.get_max_items(modality, max_tokens)
+
+            max_items_per_prompt_by_modality[modality] = max_items_per_prompt
+            max_items_per_batch_by_modality[modality] = max_items_per_batch
+
+        self.max_tokens_by_modality = max_tokens_by_modality
+        self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
+        self.max_items_per_batch_by_modality = max_items_per_batch_by_modality
+
+    def get_modality_with_max_tokens(self) -> tuple[str, int]:
+        max_tokens_by_modality = self.max_tokens_by_modality
+        modality, max_tokens = max(max_tokens_by_modality.items(),
+                                   key=lambda item: item[1])
+
+        return modality, max_tokens
+
+    def get_encoder_budget(self) -> int:
+        return min(self.max_num_encoder_input_tokens, self.encoder_cache_size)
+
+    def get_max_items(
+        self,
+        modality: str,
+        max_tokens_per_item: int,
+    ) -> tuple[int, int]:
+        if max_tokens_per_item == 0:
+            return 0, 0
+
+        # Check how many items of this modality can be supported by
+        # the encoder budget.
+        encoder_budget = self.get_encoder_budget()
+
+        # TODO: handle encoder-decoder models once we support them.
+        if encoder_budget == 0:
+            return 0, 0
+
+        max_encoder_items_per_batch = encoder_budget // max_tokens_per_item
+
+        # Check how many items of this modality can be supported by
+        # the decoder budget.
+        mm_limit = self.mm_limits[modality]
+
+        max_items_per_prompt = max(
+            1,
+            min(mm_limit, self.max_model_len // max_tokens_per_item),
+        )
+
+        scheduler_config = self.scheduler_config
+        max_num_reqs = self.max_num_reqs
+
+        if not scheduler_config.enable_chunked_prefill:
+            max_num_reqs = min(
+                max_num_reqs,
+                scheduler_config.max_num_batched_tokens // max_tokens_per_item,
+            )
+
+        max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt
+
+        max_items_per_batch = max(
+            1,
+            min(max_encoder_items_per_batch, max_decoder_items_per_batch),
+        )
+
+        return max_items_per_prompt, max_items_per_batch
+
+
 def sanity_check_mm_encoder_outputs(
     mm_embeddings: MultiModalEmbeddings,
     expected_num_items: int,
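
For reference, a rough worked example of the get_max_items() arithmetic defined above, with purely hypothetical numbers (in a real run the budget comes from compute_encoder_budget and the limits from the multimodal registry and scheduler config):

    encoder_budget = 16384           # min(max_num_encoder_input_tokens, encoder_cache_size)
    max_tokens_per_item = 4096       # e.g. encoder tokens per image for this model
    max_encoder_items_per_batch = encoder_budget // max_tokens_per_item           # 4

    mm_limit = 2                     # per-prompt item limit for this modality
    max_model_len = 8192
    max_items_per_prompt = max(1, min(mm_limit, max_model_len // max_tokens_per_item))  # 2

    max_num_reqs = 4                 # unchanged here, assuming chunked prefill is enabled
    max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt             # 8

    max_items_per_batch = max(1, min(max_encoder_items_per_batch,
                                     max_decoder_items_per_batch))                # 4

With these numbers, profiling would build a dummy batch of 4 items of this modality.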